#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import unittest
import logging
import os
from collections import OrderedDict
from decimal import Decimal
from typing import Iterator, Tuple, Any
from pyspark.loose_version import LooseVersion
from pyspark.sql import Row, functions as sf
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import (
IntegerType,
DoubleType,
ArrayType,
BinaryType,
ByteType,
LongType,
DecimalType,
ShortType,
FloatType,
StringType,
BooleanType,
StructType,
StructField,
NullType,
MapType,
YearMonthIntervalType,
)
from pyspark.errors import PythonException, PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import assertDataFrameEqual
from pyspark.util import is_remote_only
if have_pyarrow and have_pandas:
import pandas as pd
from pandas.testing import assert_frame_equal
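# The mixin below covers grouped map pandas UDFs in their different forms:
# groupby(...).apply(pandas_udf(..., PandasUDFType.GROUPED_MAP)),
# groupby(...).applyInPandas(func, schema), and the iterator variant in which the
# function receives an Iterator[pd.DataFrame] and yields pandas DataFrames.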
@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
class ApplyInPandasTestsMixin:
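# Shared fixture: ids 0-9, each exploded against the values 20-29, giving 100 rows
# of (id, v) whose per-group mean of v is 24.5 (relied upon by _test_apply_in_pandas).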
@property
def data(self):
return (
self.spark.range(10)
.withColumn("vs", sf.array([sf.lit(i) for i in range(20, 30)]))
.withColumn("v", sf.explode(sf.col("vs")))
.drop("vs")
)
def test_supported_types(self):
values = [
1,
2,
3,
4,
5,
1.1,
2.2,
Decimal(1.123),
[1, 2, 2],
True,
"hello",
bytearray([0x01, 0x02]),
None,
]
output_fields = [
("id", IntegerType()),
("byte", ByteType()),
("short", ShortType()),
("int", IntegerType()),
("long", LongType()),
("float", FloatType()),
("double", DoubleType()),
("decim", DecimalType(10, 3)),
("array", ArrayType(IntegerType())),
("bool", BooleanType()),
("str", StringType()),
("bin", BinaryType()),
("null", NullType()),
]
output_schema = StructType([StructField(*x) for x in output_fields])
df = self.spark.createDataFrame([values], schema=output_schema)
# Different forms of grouped map pandas UDF; all three produce the same results
udf1 = pandas_udf(
lambda pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2,
int=pdf.int * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + "there",
array=pdf.array,
bin=pdf.bin,
null=pdf.null,
),
output_schema,
PandasUDFType.GROUPED_MAP,
)
udf2 = pandas_udf(
lambda _, pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2,
int=pdf.int * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + "there",
array=pdf.array,
bin=pdf.bin,
null=pdf.null,
),
output_schema,
PandasUDFType.GROUPED_MAP,
)
udf3 = pandas_udf(
lambda key, pdf: pdf.assign(
id=key[0],
byte=pdf.byte * 2,
short=pdf.short * 2,
int=pdf.int * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + "there",
array=pdf.array,
bin=pdf.bin,
null=pdf.null,
),
output_schema,
PandasUDFType.GROUPED_MAP,
)
result1 = df.groupby("id").apply(udf1).sort("id").toPandas()
expected1 = df.toPandas().groupby("id").apply(udf1.func).reset_index(drop=True)
result2 = df.groupby("id").apply(udf2).sort("id").toPandas()
expected2 = expected1
result3 = df.groupby("id").apply(udf3).sort("id").toPandas()
expected3 = expected1
assert_frame_equal(expected1, result1)
assert_frame_equal(expected2, result2)
assert_frame_equal(expected3, result3)
def test_array_type_correct(self):
df = self.data.withColumn("arr", sf.array(sf.col("id"))).repartition(1, "id")
output_schema = StructType(
[
StructField("id", LongType()),
StructField("v", IntegerType()),
StructField("arr", ArrayType(LongType())),
]
)
udf = pandas_udf(lambda pdf: pdf, output_schema, PandasUDFType.GROUPED_MAP)
result = df.groupby("id").apply(udf).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_register_grouped_map_udf(self):
with self.quiet(), self.temp_func("foo_udf"):
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with self.assertRaises(PySparkTypeError) as pe:
self.spark.catalog.registerFunction("foo_udf", foo_udf)
self.check_error(
exception=pe.exception,
errorClass="INVALID_UDF_EVAL_TYPE",
messageParameters={
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
"SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF, SQL_SCALAR_ARROW_ITER_UDF, "
"SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF, "
"SQL_GROUPED_AGG_PANDAS_ITER_UDF or SQL_GROUPED_AGG_ARROW_ITER_UDF"
},
)
def test_decorator(self):
df = self.data
@pandas_udf("id long, v int, v1 double, v2 long", PandasUDFType.GROUPED_MAP)
def foo(pdf):
return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)
result = df.groupby("id").apply(foo).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(foo.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_coerce(self):
df = self.data
foo = pandas_udf(lambda pdf: pdf, "id long, v double", PandasUDFType.GROUPED_MAP)
result = df.groupby("id").apply(foo).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(foo.func).reset_index(drop=True)
expected = expected.assign(v=expected.v.astype("float64"))
assert_frame_equal(expected, result)
def test_complex_groupby(self):
df = self.data
@pandas_udf("id long, v int, norm double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
v = pdf.v
return pdf.assign(norm=(v - v.mean()) / v.std())
result = df.groupby(sf.col("id") % 2 == 0).apply(normalize).sort("id", "v").toPandas()
pdf = df.toPandas()
expected = pdf.groupby(pdf["id"] % 2 == 0, as_index=False).apply(normalize.func)
expected = expected.sort_values(["id", "v"]).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype("float64"))
assert_frame_equal(expected, result)
def test_empty_groupby(self):
df = self.data
@pandas_udf("id long, v int, norm double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
v = pdf.v
return pdf.assign(norm=(v - v.mean()) / v.std())
result = df.groupby().apply(normalize).sort("id", "v").toPandas()
pdf = df.toPandas()
expected = normalize.func(pdf)
expected = expected.sort_values(["id", "v"]).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype("float64"))
assert_frame_equal(expected, result)
def test_apply_in_pandas_not_returning_pandas_dataframe(self):
with self.quiet():
self.check_apply_in_pandas_not_returning_pandas_dataframe()
def check_apply_in_pandas_not_returning_pandas_dataframe(self):
with self.assertRaisesRegex(
PythonException,
"Return type of the user-defined function should be pandas.DataFrame, but is tuple.",
):
self._test_apply_in_pandas(lambda key, pdf: key)
def test_apply_in_pandas_returning_column_names(self):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
)
def test_apply_in_pandas_returning_no_column_names(self):
self._test_apply_in_pandas(lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]))
def test_apply_in_pandas_returning_column_names_sometimes(self):
def stats(key, pdf):
if key[0] % 2:
return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
else:
return pd.DataFrame([key + (pdf.v.mean(),)])
self._test_apply_in_pandas(stats)
def test_apply_in_pandas_returning_wrong_column_names(self):
with self.quiet():
self.check_apply_in_pandas_returning_wrong_column_names()
def check_apply_in_pandas_returning_wrong_column_names(self):
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pandas.DataFrame do not match specified schema. "
"Missing: mean. Unexpected: median, std.",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame(
[key + (pdf.v.median(), pdf.v.std())], columns=["id", "median", "std"]
)
)
def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
with self.quiet():
self.check_apply_in_pandas_returning_no_column_names_and_wrong_amount()
def check_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
with self.assertRaisesRegex(
PythonException,
"Number of columns of the returned pandas.DataFrame doesn't match "
"specified schema. Expected: 2 Actual: 3",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())])
)
@unittest.skipIf(
os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled"
)
def test_apply_in_pandas_returning_empty_dataframe(self):
self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame())
@unittest.skipIf(
os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled"
)
def test_apply_in_pandas_returning_incompatible_type(self):
with self.quiet():
self.check_apply_in_pandas_returning_incompatible_type()
def check_apply_in_pandas_returning_incompatible_type(self):
for safely in [True, False]:
with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
{"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
):
# Converting a returned string into a double column raises a ValueError
with self.subTest(convert="string to double"):
pandas_type_name = "object" if LooseVersion(pd.__version__) < "3.0.0" else "str"
expected = (
rf"ValueError: Exception thrown when converting pandas.Series \({pandas_type_name}\) "
r"with name 'mean' to Arrow Array \(double\)."
)
if safely:
expected = expected + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
with self.assertRaisesRegex(PythonException, expected):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + ("test_string",)]),
output_schema="id long, mean double",
)
# Converting a returned float64 into a string column raises a TypeError
with self.subTest(convert="double to string"):
with self.assertRaisesRegex(
PythonException,
r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
r"with name 'mean' to Arrow Array \(string\).",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]),
output_schema="id long, mean string",
)
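# spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled controls whether a
# plain Python int returned for a decimal column is coerced to Decimal; with the flag
# off, the Arrow conversion is expected to fail, as asserted below.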
def test_apply_in_pandas_int_to_decimal_coercion(self):
def int_to_decimal_func(key, pdf):
return pd.DataFrame([{"id": key[0], "decimal_result": 12345}])
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
):
result = (
self.data.groupby("id")
.applyInPandas(int_to_decimal_func, schema="id long, decimal_result decimal(10,2)")
.collect()
)
self.assertTrue(len(result) > 0)
for row in result:
self.assertEqual(row.decimal_result, 12345.00)
with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
with self.assertRaisesRegex(
PythonException, "Exception thrown when converting pandas.Series"
):
(
self.data.groupby("id")
.applyInPandas(
int_to_decimal_func, schema="id long, decimal_result decimal(10,2)"
)
.collect()
)
def test_datatype_string(self):
df = self.data
foo_udf = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
"id long, v int, v1 double, v2 long",
PandasUDFType.GROUPED_MAP,
)
result = df.groupby("id").apply(foo_udf).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(foo_udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_wrong_return_type(self):
with self.quiet():
self.check_wrong_return_type()
def check_wrong_return_type(self):
with self.assertRaisesRegex(
NotImplementedError,
"Invalid return type.*grouped map Pandas UDF.*ArrayType.*YearMonthIntervalType",
):
pandas_udf(
lambda pdf: pdf,
StructType().add("id", LongType()).add("v", ArrayType(YearMonthIntervalType())),
PandasUDFType.GROUPED_MAP,
)
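# GroupedData.apply only accepts a pandas UDF of type GROUPED_MAP; plain Python UDFs,
# scalar pandas UDFs, ordinary Columns, and zero-argument functions are rejected with
# INVALID_UDF_EVAL_TYPE or INVALID_PANDAS_UDF, as the checks below verify.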
def test_wrong_args(self):
with self.quiet():
self.check_wrong_args()
def check_wrong_args(self):
df = self.data
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(lambda x: x)
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(udf(lambda x: x, DoubleType()))
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(sf.sum(df.v))
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(df.v + 1)
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType()))
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
df.groupby("id").apply(
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))
)
def test_wrong_args_in_apply_func(self):
df1 = self.spark.range(11)
df2 = self.spark.range(22)
with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
df1.groupby("id").applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
df1.groupby("id").applyInArrow(lambda: 1, StructType([StructField("d", DoubleType())]))
with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
lambda: 1, StructType([StructField("d", DoubleType())])
)
with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
lambda: 1, StructType([StructField("d", DoubleType())])
)
def test_unsupported_types(self):
with self.quiet():
self.check_unsupported_types()
def check_unsupported_types(self):
common_err_msg = "Invalid return type.*grouped map Pandas UDF.*"
unsupported_types = [
StructField("array_struct", ArrayType(YearMonthIntervalType())),
StructField("map", MapType(StringType(), YearMonthIntervalType())),
]
for unsupported_type in unsupported_types:
with self.subTest(unsupported_type=unsupported_type.name):
schema = StructType([StructField("id", LongType(), True), unsupported_type])
with self.assertRaisesRegex(NotImplementedError, common_err_msg):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
# Regression test for SPARK-23314
def test_timestamp_dst(self):
# Daylight saving time in Los Angeles ended on Sun, Nov 1, 2015 at 2:00 am
dt = [
datetime.datetime(2015, 11, 1, 0, 30),
datetime.datetime(2015, 11, 1, 1, 30),
datetime.datetime(2015, 11, 1, 2, 30),
]
df = self.spark.createDataFrame(dt, "timestamp").toDF("time")
foo_udf = pandas_udf(lambda pdf: pdf, "time timestamp", PandasUDFType.GROUPED_MAP)
result = df.groupby("time").apply(foo_udf).sort("time")
assert_frame_equal(df.toPandas(), result.toPandas())
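# When the wrapped function takes two arguments, the first is the grouping key passed
# as a tuple of numpy scalars (one entry per grouping expression); an empty groupby
# passes an empty tuple. The assertions in foo1/foo2/foo3 below check exactly that.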
def test_udf_with_key(self):
import numpy as np
df = self.data
pdf = df.toPandas()
def foo1(key, pdf):
assert type(key) == tuple
assert type(key[0]) == np.int64
return pdf.assign(
v1=key[0], v2=pdf.v * key[0], v3=pdf.v * pdf.id, v4=pdf.v * pdf.id.mean()
)
def foo2(key, pdf):
assert type(key) == tuple
assert type(key[0]) == np.int64
assert type(key[1]) == np.int32
return pdf.assign(v1=key[0], v2=key[1], v3=pdf.v * key[0], v4=pdf.v + key[1])
def foo3(key, pdf):
assert type(key) == tuple
assert len(key) == 0
return pdf.assign(v1=pdf.v * pdf.id)
# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
udf1 = pandas_udf(
foo1, "id long, v int, v1 long, v2 int, v3 long, v4 double", PandasUDFType.GROUPED_MAP
)
udf2 = pandas_udf(
foo2, "id long, v int, v1 long, v2 int, v3 int, v4 int", PandasUDFType.GROUPED_MAP
)
udf3 = pandas_udf(foo3, "id long, v int, v1 long", PandasUDFType.GROUPED_MAP)
# Test groupby column
result1 = df.groupby("id").apply(udf1).sort("id", "v").toPandas()
expected1 = (
pdf.groupby("id", as_index=False)
.apply(lambda x: udf1.func((x.id.iloc[0],), x))
.sort_values(["id", "v"])
.reset_index(drop=True)
)
assert_frame_equal(expected1, result1)
# Test groupby expression
result2 = df.groupby(df.id % 2).apply(udf1).sort("id", "v").toPandas()
expected2 = (
pdf.groupby(pdf.id % 2, as_index=False)
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))
.sort_values(["id", "v"])
.reset_index(drop=True)
)
assert_frame_equal(expected2, result2)
# Test complex groupby
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort("id", "v").toPandas()
expected3 = (
pdf.groupby([pdf.id, pdf.v % 2], as_index=False)
.apply(
lambda x: udf2.func(
(
x.id.iloc[0],
(x.v % 2).iloc[0],
),
x,
)
)
.sort_values(["id", "v"])
.reset_index(drop=True)
)
assert_frame_equal(expected3, result3)
# Test empty groupby
result4 = df.groupby().apply(udf3).sort("id", "v").toPandas()
expected4 = udf3.func((), pdf)
assert_frame_equal(expected4, result4)
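# Column matching for the returned pandas.DataFrame: string column labels are matched
# against the schema by name, while integer/positional labels are taken by position,
# and a label missing from the schema raises an error, as checked below.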
def test_column_order(self):
with self.quiet():
self.check_column_order()
def check_column_order(self):
# Helper function to set column names from a list
def rename_pdf(pdf, names):
pdf.rename(
columns={old: new for old, new in zip(pdf.columns, names)}, inplace=True
)
df = self.data
grouped_df = df.groupby("id")
grouped_pdf = df.toPandas().groupby("id", as_index=False)
# Function returns a pdf with the required column names, but in an order different from the schema
def change_col_order(pdf):
# Constructing a DataFrame from a dict should result in the same order,
# but use OrderedDict to ensure the pdf column order is different than schema
return pd.DataFrame.from_dict(
OrderedDict([("id", pdf.id), ("u", pdf.v * 2), ("v", pdf.v)])
)
ordered_udf = pandas_udf(
change_col_order, "id long, v int, u int", PandasUDFType.GROUPED_MAP
)
# The UDF result should assign columns by name from the pdf
result = grouped_df.apply(ordered_udf).sort("id", "v").select("id", "u", "v").toPandas()
pd_result = grouped_pdf.apply(change_col_order)
expected = pd_result.sort_values(["id", "v"]).reset_index(drop=True)
assert_frame_equal(expected, result)
# Function returns a pdf with positional columns, indexed by range
def range_col_order(pdf):
# Create a DataFrame with positional columns, fix types to long
return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype="int64")
range_udf = pandas_udf(
range_col_order, "id long, u long, v long", PandasUDFType.GROUPED_MAP
)
# The UDF result uses positional columns from the pdf
result = grouped_df.apply(range_udf).sort("id", "v").select("id", "u", "v").toPandas()
pd_result = grouped_pdf.apply(range_col_order)
rename_pdf(pd_result, ["id", "u", "v"])
expected = pd_result.sort_values(["id", "v"]).reset_index(drop=True)
assert_frame_equal(expected, result)
# Function returns a pdf with columns indexed with integers
def int_index(pdf):
return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)]))
int_index_udf = pandas_udf(int_index, "id long, u int, v int", PandasUDFType.GROUPED_MAP)
# The UDF result should assign columns by position of integer index
result = grouped_df.apply(int_index_udf).sort("id", "v").select("id", "u", "v").toPandas()
pd_result = grouped_pdf.apply(int_index)
rename_pdf(pd_result, ["id", "u", "v"])
expected = pd_result.sort_values(["id", "v"]).reset_index(drop=True)
assert_frame_equal(expected, result)
@pandas_udf("id long, v int", PandasUDFType.GROUPED_MAP)
def column_name_typo(pdf):
return pd.DataFrame({"iid": pdf.id, "v": pdf.v})
@pandas_udf("id long, v decimal", PandasUDFType.GROUPED_MAP)
def invalid_positional_types(pdf):
return pd.DataFrame([(1, datetime.date(2020, 10, 5))])
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pandas.DataFrame do not match "
"specified schema. Missing: id. Unexpected: iid.",
):
grouped_df.apply(column_name_typo).collect()
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
grouped_df.apply(invalid_positional_types).collect()
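# With spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName set to False,
# columns are assigned purely by position, so the mismatched names ("x", "y") below
# still map onto the schema columns ("a", "b").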
def test_positional_assignment_conf(self):
with self.sql_conf(
{"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
):
@pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
def foo(_):
return pd.DataFrame([("hi", 1)], columns=["x", "y"])
df = self.data
result = df.groupBy("id").apply(foo).select("a", "b").collect()
for r in result:
self.assertEqual(r.a, "hi")
self.assertEqual(r.b, 1)
def test_self_join_with_pandas(self):
@pandas_udf("key long, col string", PandasUDFType.GROUPED_MAP)
def dummy_pandas_udf(df):
return df[["key", "col"]]
df = self.spark.createDataFrame(
[Row(key=1, col="A"), Row(key=1, col="B"), Row(key=2, col="C")]
)
df_with_pandas = df.groupBy("key").apply(dummy_pandas_udf)
# this was throwing an AnalysisException before SPARK-24208
res = df_with_pandas.alias("temp0").join(
df_with_pandas.alias("temp1"), sf.col("temp0.key") == sf.col("temp1.key")
)
self.assertEqual(res.count(), 5)
def test_mixed_scalar_udfs_followed_by_groupby_apply(self):
df = self.spark.range(0, 10).toDF("v1")
df = df.withColumn("v2", udf(lambda x: x + 1, "int")(df["v1"])).withColumn(
"v3", pandas_udf(lambda x: x + 2, "int")(df["v1"])
)
result = df.groupby().apply(
pandas_udf(
lambda x: pd.DataFrame([x.sum().sum()]), "sum int", PandasUDFType.GROUPED_MAP
)
)
self.assertEqual(result.collect()[0]["sum"], 165)
def test_grouped_with_empty_partition(self):
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
num_parts = len(data) + 1
df = self.spark.createDataFrame(data).repartition(num_parts)
f = pandas_udf(
lambda pdf: pdf.assign(x=pdf["x"].sum()), "id long, x int", PandasUDFType.GROUPED_MAP
)
result = df.groupBy("id").apply(f).sort("id").collect()
self.assertEqual(result, expected)
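# Grouping by a column together with sf.window("ts", "5 days") invokes the UDF once per
# (group, window) pair; the expected dict below maps each id to the ids that fall into
# the same window.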
def test_grouped_over_window(self):
data = [
(0, 1, "2018-03-10T00:00:00+00:00", [0]),
(1, 2, "2018-03-11T00:00:00+00:00", [0]),
(2, 2, "2018-03-12T00:00:00+00:00", [0]),
(3, 3, "2018-03-15T00:00:00+00:00", [0]),
(4, 3, "2018-03-16T00:00:00+00:00", [0]),
(5, 3, "2018-03-17T00:00:00+00:00", [0]),
(6, 3, "2018-03-21T00:00:00+00:00", [0]),
]
expected = {0: [0], 1: [1, 2], 2: [1, 2], 3: [3, 4, 5], 4: [3, 4, 5], 5: [3, 4, 5], 6: [6]}
df = self.spark.createDataFrame(data, ["id", "group", "ts", "result"])
df = df.select(
sf.col("id"), sf.col("group"), sf.col("ts").cast("timestamp"), sf.col("result")
)
def f(pdf):
# Assign each result element the ids of the windowed group
pdf["result"] = [pdf["id"]] * len(pdf)
return pdf
result = (
df.groupby("group", sf.window("ts", "5 days"))
.applyInPandas(f, df.schema)
.select("id", "result")
.orderBy("id")
.collect()
)
self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
def test_grouped_over_window_with_key(self):
data = [
(0, 1, "2018-03-10T00:00:00+00:00", [0]),
(1, 2, "2018-03-11T00:00:00+00:00", [0]),
(2, 2, "2018-03-12T00:00:00+00:00", [0]),
(3, 3, "2018-03-15T00:00:00+00:00", [0]),
(4, 3, "2018-03-16T00:00:00+00:00", [0]),
(5, 3, "2018-03-17T00:00:00+00:00", [0]),
(6, 3, "2018-03-21T00:00:00+00:00", [0]),
]
timezone = self.spark.conf.get("spark.sql.session.timeZone")
expected_window = [
{
key: (
pd.Timestamp(ts)
.tz_localize(datetime.timezone.utc)
.tz_convert(timezone)
.tz_localize(None)
)
for key, ts in w.items()
}
for w in [
{
"start": datetime.datetime(2018, 3, 10, 0, 0),
"end": datetime.datetime(2018, 3, 15, 0, 0),
},
{
"start": datetime.datetime(2018, 3, 15, 0, 0),
"end": datetime.datetime(2018, 3, 20, 0, 0),
},
{
"start": datetime.datetime(2018, 3, 20, 0, 0),
"end": datetime.datetime(2018, 3, 25, 0, 0),
},
]
]
expected_key = {
0: (1, expected_window[0]),
1: (2, expected_window[0]),
2: (2, expected_window[0]),
3: (3, expected_window[1]),
4: (3, expected_window[1]),
5: (3, expected_window[1]),
6: (3, expected_window[2]),
}
# id -> the group value repeated once for each record in the same window
expected = {0: [1], 1: [2, 2], 2: [2, 2], 3: [3, 3, 3], 4: [3, 3, 3], 5: [3, 3, 3], 6: [3]}
df = self.spark.createDataFrame(data, ["id", "group", "ts", "result"])
df = df.select(
sf.col("id"), sf.col("group"), sf.col("ts").cast("timestamp"), sf.col("result")
)
def f(key, pdf):
group = key[0]
window_range = key[1]
# Make sure the key carries the correct group and window values
for _, i in pdf.id.items():
assert expected_key[i][0] == group, "{} != {}".format(expected_key[i][0], group)
assert expected_key[i][1] == window_range, "{} != {}".format(
expected_key[i][1], window_range
)
return pdf.assign(result=[[group] * len(pdf)] * len(pdf))
result = (
df.groupby("group", sf.window("ts", "5 days"))
.applyInPandas(f, df.schema)
.select("id", "result")
.orderBy("id")
.collect()
)
self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
def my_pandas_udf(pdf):
return pdf.assign(score=0.5)
df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
row = (
df.groupby("COLUMN")
.applyInPandas(my_pandas_udf, schema="column integer, score float")
.first()
)
self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
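# Helper: runs f through groupby("id").applyInPandas over the shared fixture and checks
# the result against one row per id with mean(v) == 24.5.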
def _test_apply_in_pandas(self, f, output_schema="id long, mean double"):
df = self.data
result = (
df.groupby("id").applyInPandas(f, schema=output_schema).sort("id", "mean").toPandas()
)
expected = df.select("id").distinct().withColumn("mean", sf.lit(24.5)).toPandas()
assert_frame_equal(expected, result)
def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df):
"""Tests some returned DataFrames are empty."""
df = self.data
def stats(key, pdf):
if key[0] % 2 == 0:
return pd.DataFrame([key + (pdf.v.mean(),)])
return empty_df
result = (
df.groupby("id")
.applyInPandas(stats, schema="id long, mean double")
.sort("id", "mean")
.collect()
)
actual_ids = {row[0] for row in result}
expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 == 0}
self.assertSetEqual(expected_ids, actual_ids)
self.assertEqual(len(expected_ids), len(result))
for row in result:
self.assertEqual(24.5, row[1])
def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, error):
with self.quiet():
with self.assertRaisesRegex(PythonException, error):
self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
def test_arrow_cast_enabled_numeric_to_decimal(self):
import numpy as np
columns = [
"int8",
"int16",
"int32",
"uint8",
"uint16",
"uint32",
"float64",
]
pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
df = self.spark.range(2).repartition(1)
for column in columns:
with self.subTest(column=column):
v = pdf[column].iloc[:1]
schema_str = "id long, value decimal(10,0)"
@pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
def test(pdf):
return pdf.assign(**{"value": v})
row = df.groupby("id").apply(test).first()
res = row[1]
self.assertEqual(res, Decimal("1"))
def test_arrow_cast_enabled_str_to_numeric(self):
df = self.spark.range(2).repartition(1)
types = ["int", "long", "float", "double"]
for type_str in types:
with self.subTest(type=type_str):
schema_str = "id long, value " + type_str
@pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
def test(pdf):
return pdf.assign(value=pd.Series(["123"]))
row = df.groupby("id").apply(test).first()
self.assertEqual(row[1], 123)
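# Batch slicing: maxRecordsPerBatch / maxBytesPerBatch only affect how Arrow batches are
# sliced in transit; each group must still be presented to the UDF as a single
# pandas.DataFrame, which the len(pdf) assertion below verifies.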
def test_arrow_batch_slicing(self):
n = 100000
df = self.spark.range(n).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df = df.withColumns(cols)
def min_max_v(pdf):
assert len(pdf) == n / 2, len(pdf)
return pd.DataFrame(
{
"key": [pdf.key.iloc[0]],
"min": [pdf.v.min()],
"max": [pdf.v.max()],
}
)
expected = (
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()
for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
"spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
}
):
result = (
df.groupBy("key")
.applyInPandas(min_max_v, "key long, min long, max long")
.sort("key")
).collect()
self.assertEqual(expected, result)
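# A maxRecordsPerBatch of 0 or a negative value is expected to disable the per-batch
# record limit; grouped map execution should be unaffected, so test_complex_groupby is
# simply re-run under those settings.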
def test_negative_and_zero_batch_size(self):
for batch_size in [0, -1]:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
ApplyInPandasTestsMixin.test_complex_groupby(self)
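# With spark.sql.pyspark.worker.logging.enabled set, records emitted via the standard
# logging module inside the UDF are captured and exposed through
# spark.tvf.python_worker_logs(), which the test below inspects.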
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_apply_in_pandas_with_logging(self):
import pandas as pd
def func_with_logging(pdf):
assert isinstance(pdf, pd.DataFrame)
logger = logging.getLogger("test_pandas_grouped_map")
logger.warning(
f"pandas grouped map: {dict(id=list(pdf['id']), value=list(pdf['value']))}"
)
return pdf
df = self.spark.range(9).withColumn("value", sf.col("id") * 10)
grouped_df = df.groupBy((sf.col("id") % 2).cast("int"))
with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
assertDataFrameEqual(
grouped_df.applyInPandas(func_with_logging, "id long, value long"),
df,
)
logs = self.spark.tvf.python_worker_logs()
assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"pandas grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
context={"func_name": func_with_logging.__name__},
logger="test_pandas_grouped_map",
)
for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
],
)
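# Iterator variant of applyInPandas: the function takes an Iterator[pd.DataFrame]
# (optionally preceded by the grouping key tuple) and yields zero or more DataFrames
# per group. A minimal sketch of the two accepted signatures:
#
#   def f(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: ...
#   def f(key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: ...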
def test_apply_in_pandas_iterator_basic(self):
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)
def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
total = 0
for batch in batches:
total += batch["v"].sum()
yield pd.DataFrame({"v": [total]})
result = df.groupby("id").applyInPandas(sum_func, schema="v double").orderBy("v").collect()
self.assertEqual(len(result), 2)
self.assertEqual(result[0][0], 3.0)
self.assertEqual(result[1][0], 18.0)
def test_apply_in_pandas_iterator_with_keys(self):
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)
def sum_func(
key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
) -> Iterator[pd.DataFrame]:
total = 0
for batch in batches:
total += batch["v"].sum()
yield pd.DataFrame({"id": [key[0]], "v": [total]})
result = (
df.groupby("id")
.applyInPandas(sum_func, schema="id long, v double")
.orderBy("id")
.collect()
)
self.assertEqual(len(result), 2)
self.assertEqual(result[0][0], 1)
self.assertEqual(result[0][1], 3.0)
self.assertEqual(result[1][0], 2)
self.assertEqual(result[1][1], 18.0)
def test_apply_in_pandas_iterator_batch_slicing(self):
df = self.spark.range(100000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df = df.withColumns(cols)
def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
# Collect all batches to compute min/max across the entire group
all_data = []
key_val = None
for batch in batches:
all_data.append(batch)
if key_val is None:
key_val = batch.key.iloc[0]
combined = pd.concat(all_data, ignore_index=True)
assert len(combined) == 100000 / 2, len(combined)
yield pd.DataFrame(
{
"key": [key_val],
"min": [combined.v.min()],
"max": [combined.v.max()],
}
)
expected = (
df.groupby("key")
.agg(
sf.min("v").alias("min"),
sf.max("v").alias("max"),
)
.sort("key")
).collect()
for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
"spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
}
):
result = (
df.groupBy("key")
.applyInPandas(min_max_v, "key long, min long, max long")
.sort("key")
).collect()
self.assertEqual(expected, result)
def test_apply_in_pandas_iterator_with_keys_batch_slicing(self):
df = self.spark.range(100000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df = df.withColumns(cols)
def min_max_v(
key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
) -> Iterator[pd.DataFrame]:
# Collect all batches to compute min/max across the entire group
all_data = []
for batch in batches:
all_data.append(batch)
combined = pd.concat(all_data, ignore_index=True)
assert len(combined) == 100000 / 2, len(combined)
yield pd.DataFrame(
{
"key": [key[0]],
"min": [combined.v.min()],
"max": [combined.v.max()],
}
)
expected = (
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()
for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
"spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
}
):
result = (
df.groupBy("key")
.applyInPandas(min_max_v, "key long, min long, max long")
.sort("key")
).collect()
self.assertEqual(expected, result)
def test_apply_in_pandas_iterator_multiple_output_batches(self):
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (1, 3.0), (2, 4.0), (2, 5.0), (2, 6.0)], ("id", "v")
)
def split_and_yield(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
# Yield multiple output batches for each input batch
for batch in batches:
for _, row in batch.iterrows():
# Yield each row as a separate batch to test multiple yields
yield pd.DataFrame(
{"id": [row["id"]], "v": [row["v"]], "v_doubled": [row["v"] * 2]}
)
result = (
df.groupby("id")
.applyInPandas(split_and_yield, schema="id long, v double, v_doubled double")
.orderBy("id", "v")
.collect()
)
expected = [
Row(id=1, v=1.0, v_doubled=2.0),
Row(id=1, v=2.0, v_doubled=4.0),
Row(id=1, v=3.0, v_doubled=6.0),
Row(id=2, v=4.0, v_doubled=8.0),
Row(id=2, v=5.0, v_doubled=10.0),
Row(id=2, v=6.0, v_doubled=12.0),
]
self.assertEqual(result, expected)
def test_apply_in_pandas_iterator_filter_multiple_batches(self):
df = self.spark.createDataFrame(
[(1, i * 1.0) for i in range(20)] + [(2, i * 1.0) for i in range(20)], ("id", "v")
)
def filter_and_yield(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
# Yield filtered results from each batch
for batch in batches:
# Filter even values and yield
even_batch = batch[batch["v"] % 2 == 0]
if not even_batch.empty:
yield even_batch
# Filter odd values and yield separately
odd_batch = batch[batch["v"] % 2 == 1]
if not odd_batch.empty:
yield odd_batch
result = (
df.groupby("id")
.applyInPandas(filter_and_yield, schema="id long, v double")
.orderBy("id", "v")
.collect()
)
# Verify all 40 rows are present (20 per group)
self.assertEqual(len(result), 40)
# Verify group 1 has all values 0-19
group1 = [row for row in result if row[0] == 1]
self.assertEqual(len(group1), 20)
self.assertEqual([row[1] for row in group1], [float(i) for i in range(20)])
# Verify group 2 has all values 0-19
group2 = [row for row in result if row[0] == 2]
self.assertEqual(len(group2), 20)
self.assertEqual([row[1] for row in group2], [float(i) for i in range(20)])
def test_apply_in_pandas_iterator_with_keys_multiple_batches(self):
df = self.spark.createDataFrame(
[
(1, "a", 1.0),
(1, "b", 2.0),
(1, "c", 3.0),
(2, "d", 4.0),
(2, "e", 5.0),
(2, "f", 6.0),
],
("id", "name", "v"),
)
def process_with_key(
key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
) -> Iterator[pd.DataFrame]:
# Yield multiple processed batches, including the key in each output
for batch in batches:
# Split batch and yield multiple output batches
for chunk_size in [1, 2]:
for i in range(0, len(batch), chunk_size):
chunk = batch.iloc[i : i + chunk_size]
if not chunk.empty:
result = chunk.assign(id=key[0], total=chunk["v"].sum())
yield result[["id", "name", "total"]]
result = (
df.groupby("id")
.applyInPandas(process_with_key, schema="id long, name string, total double")
.orderBy("id", "name")
.collect()
)
# Verify we get results (may have duplicates due to splitting)
self.assertTrue(len(result) > 6)
# Verify all original names are present
names = [row[1] for row in result]
self.assertIn("a", names)
self.assertIn("b", names)
self.assertIn("c", names)
self.assertIn("d", names)
self.assertIn("e", names)
self.assertIn("f", names)
# Verify keys are correct
for row in result:
if row[1] in ["a", "b", "c"]:
self.assertEqual(row[0], 1)
else:
self.assertEqual(row[0], 2)
def test_apply_in_pandas_iterator_process_multiple_input_batches(self):
# Create large dataset to trigger batch slicing
df = self.spark.range(100000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
def process_batches_progressively(
batches: Iterator[pd.DataFrame],
) -> Iterator[pd.DataFrame]:
# Process each input batch and yield output immediately
batch_count = 0
for batch in batches:
batch_count += 1
# Yield a summary for each input batch processed
yield pd.DataFrame(
{
"key": [batch.key.iloc[0]],
"batch_num": [batch_count],
"count": [len(batch)],
"sum": [batch.v.sum()],
}
)
# Use small batch size to force multiple input batches
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": 10000,
}
):
result = (
df.groupBy("key")
.applyInPandas(
process_batches_progressively,
schema="key long, batch_num long, count long, sum long",
)
.orderBy("key", "batch_num")
.collect()
)
# Verify we got multiple batches per group (100000/2 = 50000 rows per group)
# With maxRecordsPerBatch=10000, should get 5 batches per group
group_0_batches = [r for r in result if r[0] == 0]
group_1_batches = [r for r in result if r[0] == 1]
# Verify multiple batches were processed
self.assertGreater(len(group_0_batches), 1)
self.assertGreater(len(group_1_batches), 1)
# Verify the sum across all batches equals expected total (using Python's built-in sum)
group_0_sum = sum(r[3] for r in group_0_batches)
group_1_sum = sum(r[3] for r in group_1_batches)
# Expected: sum of even numbers 0,2,4,...,99998
expected_even_sum = sum(range(0, 100000, 2))
expected_odd_sum = sum(range(1, 100000, 2))
self.assertEqual(group_0_sum, expected_even_sum)
self.assertEqual(group_1_sum, expected_odd_sum)
def test_apply_in_pandas_iterator_streaming_aggregation(self):
# Create dataset with multiple batches per group
df = self.spark.range(50000).select(
(sf.col("id") % 3).alias("key"),
(sf.col("id") % 100).alias("category"),
sf.col("id").alias("value"),
)
def streaming_aggregate(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
# Maintain running aggregates and yield intermediate results
running_sum = 0
running_count = 0
for batch in batches:
# Update running aggregates
running_sum += batch.value.sum()
running_count += len(batch)
# Yield current stats after processing each batch
yield pd.DataFrame(
{
"key": [batch.key.iloc[0]],
"running_count": [running_count],
"running_avg": [running_sum / running_count],
}
)
# Force multiple batches with small batch size
with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": 5000,
}
):
result = (
df.groupBy("key")
.applyInPandas(
streaming_aggregate, schema="key long, running_count long, running_avg double"
)
.collect()
)
# Verify we got multiple rows per group (one per input batch)
for key_val in [0, 1, 2]:
key_results = [r for r in result if r[0] == key_val]
# Should have multiple batches
# (50000/3 ≈ 16667 rows per group, with 5000 per batch = ~4 batches)
self.assertGreater(len(key_results), 1, f"Expected multiple batches for key {key_val}")
# Verify running_count increases monotonically
counts = [r[1] for r in key_results]
for i in range(1, len(counts)):
self.assertGreater(
counts[i], counts[i - 1], "Running count should increase with each batch"
)
def test_apply_in_pandas_iterator_partial_iteration(self):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
def func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
# Only consume the first batch from the iterator
first = next(batches)
yield pd.DataFrame({"value": first["id"] % 4})
df = self.spark.range(20)
grouped_df = df.groupBy((sf.col("id") % 4).cast("int"))
# Should get two records for each group (first batch only)
expected = [Row(value=x) for x in [0, 0, 1, 1, 2, 2, 3, 3]]
actual = grouped_df.applyInPandas(func, "value long").collect()
self.assertEqual(actual, expected)
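# The remaining tests re-run grouped map UDFs under each Arrow compression codec
# setting ("none", "zstd", "lz4") and check that the results are unchanged.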
def test_grouped_map_pandas_udf_with_compression_codec(self):
# Test grouped map Pandas UDF with different compression codec settings
@pandas_udf("id long, v int, v1 double", PandasUDFType.GROUPED_MAP)
def foo(pdf):
return pdf.assign(v1=pdf.v * pdf.id * 1.0)
df = self.data
pdf = df.toPandas()
expected = pdf.groupby("id", as_index=False).apply(foo.func).reset_index(drop=True)
for codec in ["none", "zstd", "lz4"]:
with self.subTest(compressionCodec=codec):
with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
result = df.groupby("id").apply(foo).sort("id").toPandas()
assert_frame_equal(expected, result)
def test_apply_in_pandas_with_compression_codec(self):
# Test applyInPandas with different compression codec settings
def stats(key, pdf):
return pd.DataFrame([(key[0], pdf.v.mean())], columns=["id", "mean"])
df = self.data
expected = df.select("id").distinct().withColumn("mean", sf.lit(24.5)).toPandas()
for codec in ["none", "zstd", "lz4"]:
with self.subTest(compressionCodec=codec):
with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
result = (
df.groupby("id")
.applyInPandas(stats, schema="id long, mean double")
.sort("id")
.toPandas()
)
assert_frame_equal(expected, result)
def test_apply_in_pandas_iterator_with_compression_codec(self):
# Test applyInPandas with iterator and compression
df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)
def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
total = 0
for batch in batches:
total += batch["v"].sum()
yield pd.DataFrame({"v": [total]})
expected = [Row(v=3.0), Row(v=18.0)]
for codec in ["none", "zstd", "lz4"]:
with self.subTest(compressionCodec=codec):
with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
result = (
df.groupby("id")
.applyInPandas(sum_func, schema="v double")
.orderBy("v")
.collect()
)
self.assertEqual(result, expected)
class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
if __name__ == "__main__":
from pyspark.testing import main
main()