#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
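
# Tests for grouped map pandas UDFs: DataFrame.groupby(...).applyInPandas(...) and the
# legacy GroupedData.apply() with pandas_udf(..., PandasUDFType.GROUPED_MAP).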
import datetime
import unittest
from collections import OrderedDict
from decimal import Decimal
from typing import cast
from pyspark.sql import Row
from pyspark.sql.functions import (
array,
explode,
col,
lit,
udf,
sum,
pandas_udf,
PandasUDFType,
window,
)
from pyspark.sql.types import (
IntegerType,
DoubleType,
ArrayType,
BinaryType,
ByteType,
LongType,
DecimalType,
ShortType,
FloatType,
StringType,
BooleanType,
StructType,
StructField,
NullType,
MapType,
YearMonthIntervalType,
)
from pyspark.errors import PythonException, PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
if have_pandas:
import pandas as pd
from pandas.testing import assert_frame_equal
if have_pyarrow:
import pyarrow as pa # noqa: F401
@unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class GroupedApplyInPandasTestsMixin:
@property
def data(self):
return (
self.spark.range(10)
.withColumn("vs", array([lit(i) for i in range(20, 30)]))
.withColumn("v", explode(col("vs")))
.drop("vs")
)
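    # The data fixture above has 100 rows: every id in [0, 10) paired with every v in [20, 30).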
def test_supported_types(self):
values = [
1,
2,
3,
4,
5,
1.1,
2.2,
Decimal(1.123),
[1, 2, 2],
True,
"hello",
bytearray([0x01, 0x02]),
None,
]
output_fields = [
("id", IntegerType()),
("byte", ByteType()),
("short", ShortType()),
("int", IntegerType()),
("long", LongType()),
("float", FloatType()),
("double", DoubleType()),
("decim", DecimalType(10, 3)),
("array", ArrayType(IntegerType())),
("bool", BooleanType()),
("str", StringType()),
("bin", BinaryType()),
("null", NullType()),
]
output_schema = StructType([StructField(*x) for x in output_fields])
df = self.spark.createDataFrame([values], schema=output_schema)
        # Different forms of grouped map pandas UDF; all of them should produce the same result
udf1 = pandas_udf(
lambda pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2,
int=pdf.int * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + "there",
array=pdf.array,
bin=pdf.bin,
null=pdf.null,
),
output_schema,
PandasUDFType.GROUPED_MAP,
)
udf2 = pandas_udf(
lambda _, pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2,
int=pdf.int * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + "there",
array=pdf.array,
bin=pdf.bin,
null=pdf.null,
),
output_schema,
PandasUDFType.GROUPED_MAP,
)
udf3 = pandas_udf(
lambda key, pdf: pdf.assign(
id=key[0],
byte=pdf.byte * 2,
short=pdf.short * 2,
int=pdf.int * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + "there",
array=pdf.array,
bin=pdf.bin,
null=pdf.null,
),
output_schema,
PandasUDFType.GROUPED_MAP,
)
result1 = df.groupby("id").apply(udf1).sort("id").toPandas()
expected1 = df.toPandas().groupby("id").apply(udf1.func).reset_index(drop=True)
result2 = df.groupby("id").apply(udf2).sort("id").toPandas()
expected2 = expected1
result3 = df.groupby("id").apply(udf3).sort("id").toPandas()
expected3 = expected1
assert_frame_equal(expected1, result1)
assert_frame_equal(expected2, result2)
assert_frame_equal(expected3, result3)
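    # Array columns should round-trip through a grouped map UDF with element types preserved.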
def test_array_type_correct(self):
df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
output_schema = StructType(
[
StructField("id", LongType()),
StructField("v", IntegerType()),
StructField("arr", ArrayType(LongType())),
]
)
udf = pandas_udf(lambda pdf: pdf, output_schema, PandasUDFType.GROUPED_MAP)
result = df.groupby("id").apply(udf).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
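    # Grouped map UDFs cannot be registered as SQL functions and should raise a typed error.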
def test_register_grouped_map_udf(self):
with self.quiet():
self.check_register_grouped_map_udf()
def check_register_grouped_map_udf(self):
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with self.assertRaises(PySparkTypeError) as pe:
self.spark.catalog.registerFunction("foo_udf", foo_udf)
self.check_error(
exception=pe.exception,
error_class="INVALID_UDF_EVAL_TYPE",
message_parameters={
"eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
"SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
},
)
def test_decorator(self):
df = self.data
@pandas_udf("id long, v int, v1 double, v2 long", PandasUDFType.GROUPED_MAP)
def foo(pdf):
return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)
result = df.groupby("id").apply(foo).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(foo.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_coerce(self):
df = self.data
foo = pandas_udf(lambda pdf: pdf, "id long, v double", PandasUDFType.GROUPED_MAP)
result = df.groupby("id").apply(foo).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(foo.func).reset_index(drop=True)
expected = expected.assign(v=expected.v.astype("float64"))
assert_frame_equal(expected, result)
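    # Grouping by an expression (id % 2 == 0) rather than a plain column should also work.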
def test_complex_groupby(self):
df = self.data
@pandas_udf("id long, v int, norm double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
v = pdf.v
return pdf.assign(norm=(v - v.mean()) / v.std())
result = df.groupby(col("id") % 2 == 0).apply(normalize).sort("id", "v").toPandas()
pdf = df.toPandas()
expected = pdf.groupby(pdf["id"] % 2 == 0, as_index=False).apply(normalize.func)
expected = expected.sort_values(["id", "v"]).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype("float64"))
assert_frame_equal(expected, result)
def test_empty_groupby(self):
df = self.data
@pandas_udf("id long, v int, norm double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
v = pdf.v
return pdf.assign(norm=(v - v.mean()) / v.std())
result = df.groupby().apply(normalize).sort("id", "v").toPandas()
pdf = df.toPandas()
expected = normalize.func(pdf)
expected = expected.sort_values(["id", "v"]).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype("float64"))
assert_frame_equal(expected, result)
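    # The UDF must return a pandas.DataFrame; returning anything else (here, the key tuple) fails.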
def test_apply_in_pandas_not_returning_pandas_dataframe(self):
with self.quiet():
self.check_apply_in_pandas_not_returning_pandas_dataframe()
def check_apply_in_pandas_not_returning_pandas_dataframe(self):
with self.assertRaisesRegex(
PythonException,
"Return type of the user-defined function should be pandas.DataFrame, but is tuple.",
):
self._test_apply_in_pandas(lambda key, pdf: key)
@staticmethod
def stats_with_column_names(key, pdf):
        # column order can differ from the applyInPandas schema when column names are given
return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
@staticmethod
def stats_with_no_column_names(key, pdf):
        # columns must follow the order of the applyInPandas schema when no column names are given
return pd.DataFrame([key + (pdf.v.mean(),)])
def test_apply_in_pandas_returning_column_names(self):
self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_column_names)
def test_apply_in_pandas_returning_no_column_names(self):
self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_no_column_names)
def test_apply_in_pandas_returning_column_names_sometimes(self):
def stats(key, pdf):
if key[0] % 2:
return GroupedApplyInPandasTestsMixin.stats_with_column_names(key, pdf)
else:
return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
self._test_apply_in_pandas(stats)
def test_apply_in_pandas_returning_wrong_column_names(self):
with self.quiet():
self.check_apply_in_pandas_returning_wrong_column_names()
def check_apply_in_pandas_returning_wrong_column_names(self):
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pandas.DataFrame do not match specified schema. "
"Missing: mean. Unexpected: median, std.\n",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame(
[key + (pdf.v.median(), pdf.v.std())], columns=["id", "median", "std"]
)
)
def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
with self.quiet():
self.check_apply_in_pandas_returning_no_column_names_and_wrong_amount()
def check_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
with self.assertRaisesRegex(
PythonException,
"Number of columns of the returned pandas.DataFrame doesn't match "
"specified schema. Expected: 2 Actual: 3\n",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())])
)
def test_apply_in_pandas_returning_empty_dataframe(self):
self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame())
def test_apply_in_pandas_returning_incompatible_type(self):
with self.quiet():
self.check_apply_in_pandas_returning_incompatible_type()
def check_apply_in_pandas_returning_incompatible_type(self):
for safely in [True, False]:
with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
{"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
):
                # converting string values to an Arrow double array raises a ValueError
with self.subTest(convert="string to double"):
expected = (
r"ValueError: Exception thrown when converting pandas.Series \(object\) "
r"with name 'mean' to Arrow Array \(double\)."
)
if safely:
expected = expected + (
" It can be caused by overflows or other "
"unsafe conversions warned by Arrow. Arrow safe type check "
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
with self.assertRaisesRegex(PythonException, expected + "\n"):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]),
output_schema="id long, mean double",
)
                # converting float values to an Arrow string array raises a TypeError
with self.subTest(convert="double to string"):
with self.assertRaisesRegex(
PythonException,
r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
r"with name 'mean' to Arrow Array \(string\).\n",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]),
output_schema="id long, mean string",
)
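    # The return schema can also be given as a DDL-formatted string instead of a StructType.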
def test_datatype_string(self):
df = self.data
foo_udf = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
"id long, v int, v1 double, v2 long",
PandasUDFType.GROUPED_MAP,
)
result = df.groupby("id").apply(foo_udf).sort("id").toPandas()
expected = df.toPandas().groupby("id").apply(foo_udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_wrong_return_type(self):
with self.quiet():
self.check_wrong_return_type()
def check_wrong_return_type(self):
with self.assertRaisesRegex(
NotImplementedError,
"Invalid return type.*grouped map Pandas UDF.*ArrayType.*YearMonthIntervalType",
):
pandas_udf(
lambda pdf: pdf,
StructType().add("id", LongType()).add("v", ArrayType(YearMonthIntervalType())),
PandasUDFType.GROUPED_MAP,
)
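    # Plain Python functions, row-based UDFs, aggregate expressions, and scalar pandas UDFs
    # are all rejected by GroupedData.apply with INVALID_UDF_EVAL_TYPE or INVALID_PANDAS_UDF.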
def test_wrong_args(self):
with self.quiet():
self.check_wrong_args()
def check_wrong_args(self):
df = self.data
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(lambda x: x)
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(udf(lambda x: x, DoubleType()))
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(sum(df.v))
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(df.v + 1)
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType()))
with self.assertRaisesRegex(PySparkTypeError, "INVALID_UDF_EVAL_TYPE"):
df.groupby("id").apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
with self.assertRaisesRegex(PySparkValueError, "INVALID_PANDAS_UDF"):
df.groupby("id").apply(
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))
)
def test_unsupported_types(self):
with self.quiet():
self.check_unsupported_types()
def check_unsupported_types(self):
common_err_msg = "Invalid return type.*grouped map Pandas UDF.*"
unsupported_types = [
StructField("array_struct", ArrayType(YearMonthIntervalType())),
StructField("map", MapType(StringType(), YearMonthIntervalType())),
]
for unsupported_type in unsupported_types:
with self.subTest(unsupported_type=unsupported_type.name):
schema = StructType([StructField("id", LongType(), True), unsupported_type])
with self.assertRaisesRegex(NotImplementedError, common_err_msg):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
# Regression test for SPARK-23314
def test_timestamp_dst(self):
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
dt = [
datetime.datetime(2015, 11, 1, 0, 30),
datetime.datetime(2015, 11, 1, 1, 30),
datetime.datetime(2015, 11, 1, 2, 30),
]
df = self.spark.createDataFrame(dt, "timestamp").toDF("time")
foo_udf = pandas_udf(lambda pdf: pdf, "time timestamp", PandasUDFType.GROUPED_MAP)
result = df.groupby("time").apply(foo_udf).sort("time")
assert_frame_equal(df.toPandas(), result.toPandas())
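    # When the UDF takes two arguments, the first is the grouping key as a tuple of numpy scalars.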
def test_udf_with_key(self):
import numpy as np
df = self.data
pdf = df.toPandas()
def foo1(key, pdf):
assert type(key) == tuple
assert type(key[0]) == np.int64
return pdf.assign(
v1=key[0], v2=pdf.v * key[0], v3=pdf.v * pdf.id, v4=pdf.v * pdf.id.mean()
)
def foo2(key, pdf):
assert type(key) == tuple
assert type(key[0]) == np.int64
assert type(key[1]) == np.int32
return pdf.assign(v1=key[0], v2=key[1], v3=pdf.v * key[0], v4=pdf.v + key[1])
def foo3(key, pdf):
assert type(key) == tuple
assert len(key) == 0
return pdf.assign(v1=pdf.v * pdf.id)
# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
udf1 = pandas_udf(
foo1, "id long, v int, v1 long, v2 int, v3 long, v4 double", PandasUDFType.GROUPED_MAP
)
udf2 = pandas_udf(
foo2, "id long, v int, v1 long, v2 int, v3 int, v4 int", PandasUDFType.GROUPED_MAP
)
udf3 = pandas_udf(foo3, "id long, v int, v1 long", PandasUDFType.GROUPED_MAP)
# Test groupby column
result1 = df.groupby("id").apply(udf1).sort("id", "v").toPandas()
expected1 = (
pdf.groupby("id", as_index=False)
.apply(lambda x: udf1.func((x.id.iloc[0],), x))
.sort_values(["id", "v"])
.reset_index(drop=True)
)
assert_frame_equal(expected1, result1)
# Test groupby expression
result2 = df.groupby(df.id % 2).apply(udf1).sort("id", "v").toPandas()
expected2 = (
pdf.groupby(pdf.id % 2, as_index=False)
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))
.sort_values(["id", "v"])
.reset_index(drop=True)
)
assert_frame_equal(expected2, result2)
# Test complex groupby
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort("id", "v").toPandas()
expected3 = (
pdf.groupby([pdf.id, pdf.v % 2], as_index=False)
.apply(
lambda x: udf2.func(
(
x.id.iloc[0],
(x.v % 2).iloc[0],
),
x,
)
)
.sort_values(["id", "v"])
.reset_index(drop=True)
)
assert_frame_equal(expected3, result3)
# Test empty groupby
result4 = df.groupby().apply(udf3).sort("id", "v").toPandas()
expected4 = udf3.func((), pdf)
assert_frame_equal(expected4, result4)
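    # Returned columns are matched to the schema by name when names are given, and by position
    # when they are not; the cases below exercise both paths.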
def test_column_order(self):
with self.quiet():
self.check_column_order()
def check_column_order(self):
# Helper function to set column names from a list
def rename_pdf(pdf, names):
pdf.rename(
                columns={old: new for old, new in zip(pdf.columns, names)}, inplace=True
)
df = self.data
grouped_df = df.groupby("id")
grouped_pdf = df.toPandas().groupby("id", as_index=False)
        # Function returns a pdf with the required column names, in arbitrary order (uses a dict)
def change_col_order(pdf):
            # Constructing a DataFrame from a dict preserves insertion order, but use OrderedDict
            # to make sure the pdf column order differs from the schema
return pd.DataFrame.from_dict(
OrderedDict([("id", pdf.id), ("u", pdf.v * 2), ("v", pdf.v)])
)
ordered_udf = pandas_udf(
change_col_order, "id long, v int, u int", PandasUDFType.GROUPED_MAP
)
# The UDF result should assign columns by name from the pdf
result = grouped_df.apply(ordered_udf).sort("id", "v").select("id", "u", "v").toPandas()
pd_result = grouped_pdf.apply(change_col_order)
expected = pd_result.sort_values(["id", "v"]).reset_index(drop=True)
assert_frame_equal(expected, result)
# Function returns a pdf with positional columns, indexed by range
def range_col_order(pdf):
# Create a DataFrame with positional columns, fix types to long
return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype="int64")
range_udf = pandas_udf(
range_col_order, "id long, u long, v long", PandasUDFType.GROUPED_MAP
)
# The UDF result uses positional columns from the pdf
result = grouped_df.apply(range_udf).sort("id", "v").select("id", "u", "v").toPandas()
pd_result = grouped_pdf.apply(range_col_order)
rename_pdf(pd_result, ["id", "u", "v"])
expected = pd_result.sort_values(["id", "v"]).reset_index(drop=True)
assert_frame_equal(expected, result)
# Function returns a pdf with columns indexed with integers
def int_index(pdf):
return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)]))
int_index_udf = pandas_udf(int_index, "id long, u int, v int", PandasUDFType.GROUPED_MAP)
# The UDF result should assign columns by position of integer index
result = grouped_df.apply(int_index_udf).sort("id", "v").select("id", "u", "v").toPandas()
pd_result = grouped_pdf.apply(int_index)
rename_pdf(pd_result, ["id", "u", "v"])
expected = pd_result.sort_values(["id", "v"]).reset_index(drop=True)
assert_frame_equal(expected, result)
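        # Negative cases: a misspelled column name and an incompatible positional type should fail.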
@pandas_udf("id long, v int", PandasUDFType.GROUPED_MAP)
def column_name_typo(pdf):
return pd.DataFrame({"iid": pdf.id, "v": pdf.v})
@pandas_udf("id long, v decimal", PandasUDFType.GROUPED_MAP)
def invalid_positional_types(pdf):
return pd.DataFrame([(1, datetime.date(2020, 10, 5))])
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pandas.DataFrame do not match "
"specified schema. Missing: id. Unexpected: iid.\n",
):
grouped_df.apply(column_name_typo).collect()
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
grouped_df.apply(invalid_positional_types).collect()
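    # With assignColumnsByName disabled, returned columns are matched to the schema by position,
    # so the column names 'x' and 'y' below are ignored.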
def test_positional_assignment_conf(self):
with self.sql_conf(
{"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
):
@pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
def foo(_):
return pd.DataFrame([("hi", 1)], columns=["x", "y"])
df = self.data
result = df.groupBy("id").apply(foo).select("a", "b").collect()
for r in result:
self.assertEqual(r.a, "hi")
self.assertEqual(r.b, 1)
def test_self_join_with_pandas(self):
@pandas_udf("key long, col string", PandasUDFType.GROUPED_MAP)
def dummy_pandas_udf(df):
return df[["key", "col"]]
df = self.spark.createDataFrame(
[Row(key=1, col="A"), Row(key=1, col="B"), Row(key=2, col="C")]
)
df_with_pandas = df.groupBy("key").apply(dummy_pandas_udf)
# this was throwing an AnalysisException before SPARK-24208
res = df_with_pandas.alias("temp0").join(
df_with_pandas.alias("temp1"), col("temp0.key") == col("temp1.key")
)
self.assertEqual(res.count(), 5)
def test_mixed_scalar_udfs_followed_by_groupby_apply(self):
df = self.spark.range(0, 10).toDF("v1")
df = df.withColumn("v2", udf(lambda x: x + 1, "int")(df["v1"])).withColumn(
"v3", pandas_udf(lambda x: x + 2, "int")(df["v1"])
)
result = df.groupby().apply(
pandas_udf(
lambda x: pd.DataFrame([x.sum().sum()]), "sum int", PandasUDFType.GROUPED_MAP
)
)
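        # Expected: sum(v1)=45 (0..9) + sum(v2)=55 (1..10) + sum(v3)=65 (2..11) = 165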
self.assertEqual(result.collect()[0]["sum"], 165)
def test_grouped_with_empty_partition(self):
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
num_parts = len(data) + 1
df = self.spark.createDataFrame(data).repartition(num_parts)
f = pandas_udf(
lambda pdf: pdf.assign(x=pdf["x"].sum()), "id long, x int", PandasUDFType.GROUPED_MAP
)
result = df.groupBy("id").apply(f).sort("id").collect()
self.assertEqual(result, expected)
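    # Grouping keys can include a window expression; each (group, window) pair is one pandas group.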
def test_grouped_over_window(self):
data = [
(0, 1, "2018-03-10T00:00:00+00:00", [0]),
(1, 2, "2018-03-11T00:00:00+00:00", [0]),
(2, 2, "2018-03-12T00:00:00+00:00", [0]),
(3, 3, "2018-03-15T00:00:00+00:00", [0]),
(4, 3, "2018-03-16T00:00:00+00:00", [0]),
(5, 3, "2018-03-17T00:00:00+00:00", [0]),
(6, 3, "2018-03-21T00:00:00+00:00", [0]),
]
expected = {0: [0], 1: [1, 2], 2: [1, 2], 3: [3, 4, 5], 4: [3, 4, 5], 5: [3, 4, 5], 6: [6]}
df = self.spark.createDataFrame(data, ["id", "group", "ts", "result"])
df = df.select(col("id"), col("group"), col("ts").cast("timestamp"), col("result"))
def f(pdf):
# Assign each result element the ids of the windowed group
pdf["result"] = [pdf["id"]] * len(pdf)
return pdf
result = (
df.groupby("group", window("ts", "5 days"))
.applyInPandas(f, df.schema)
.select("id", "result")
.orderBy("id")
.collect()
)
self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
def test_grouped_over_window_with_key(self):
data = [
(0, 1, "2018-03-10T00:00:00+00:00", [0]),
(1, 2, "2018-03-11T00:00:00+00:00", [0]),
(2, 2, "2018-03-12T00:00:00+00:00", [0]),
(3, 3, "2018-03-15T00:00:00+00:00", [0]),
(4, 3, "2018-03-16T00:00:00+00:00", [0]),
(5, 3, "2018-03-17T00:00:00+00:00", [0]),
(6, 3, "2018-03-21T00:00:00+00:00", [0]),
]
timezone = self.spark.conf.get("spark.sql.session.timeZone")
expected_window = [
{
key: (
pd.Timestamp(ts)
.tz_localize(datetime.timezone.utc)
.tz_convert(timezone)
.tz_localize(None)
)
for key, ts in w.items()
}
for w in [
{
"start": datetime.datetime(2018, 3, 10, 0, 0),
"end": datetime.datetime(2018, 3, 15, 0, 0),
},
{
"start": datetime.datetime(2018, 3, 15, 0, 0),
"end": datetime.datetime(2018, 3, 20, 0, 0),
},
{
"start": datetime.datetime(2018, 3, 20, 0, 0),
"end": datetime.datetime(2018, 3, 25, 0, 0),
},
]
]
expected_key = {
0: (1, expected_window[0]),
1: (2, expected_window[0]),
2: (2, expected_window[0]),
3: (3, expected_window[1]),
4: (3, expected_window[1]),
5: (3, expected_window[1]),
6: (3, expected_window[2]),
}
        # id -> array of group values, with one entry per record in the window
expected = {0: [1], 1: [2, 2], 2: [2, 2], 3: [3, 3, 3], 4: [3, 3, 3], 5: [3, 3, 3], 6: [3]}
df = self.spark.createDataFrame(data, ["id", "group", "ts", "result"])
df = df.select(col("id"), col("group"), col("ts").cast("timestamp"), col("result"))
def f(key, pdf):
group = key[0]
window_range = key[1]
            # Make sure the key's group and window values are correct
for _, i in pdf.id.items():
assert expected_key[i][0] == group, "{} != {}".format(expected_key[i][0], group)
assert expected_key[i][1] == window_range, "{} != {}".format(
expected_key[i][1], window_range
)
return pdf.assign(result=[[group] * len(pdf)] * len(pdf))
result = (
df.groupby("group", window("ts", "5 days"))
.applyInPandas(f, df.schema)
.select("id", "result")
.orderBy("id")
.collect()
)
self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
def my_pandas_udf(pdf):
return pdf.assign(score=0.5)
df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
row = (
df.groupby("COLUMN")
.applyInPandas(my_pandas_udf, schema="column integer, score float")
.first()
)
self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
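    # Shared helper: every id group has v values 20..29, so the expected per-group mean is 24.5.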
def _test_apply_in_pandas(self, f, output_schema="id long, mean double"):
df = self.data
result = (
df.groupby("id").applyInPandas(f, schema=output_schema).sort("id", "mean").toPandas()
)
expected = df.select("id").distinct().withColumn("mean", lit(24.5)).toPandas()
assert_frame_equal(expected, result)
def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df):
"""Tests some returned DataFrames are empty."""
df = self.data
def stats(key, pdf):
if key[0] % 2 == 0:
return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
return empty_df
result = (
df.groupby("id")
.applyInPandas(stats, schema="id long, mean double")
.sort("id", "mean")
.collect()
)
actual_ids = {row[0] for row in result}
expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 == 0}
self.assertSetEqual(expected_ids, actual_ids)
self.assertEqual(len(expected_ids), len(result))
for row in result:
self.assertEqual(24.5, row[1])
def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, error):
with self.quiet():
with self.assertRaisesRegex(PythonException, error):
self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_grouped_map import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)