#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
from typing import cast

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
from pyspark.sql.types import (
    ArrayType,
    DoubleType,
    LongType,
    StructType,
    StructField,
    YearMonthIntervalType,
    Row,
)
from pyspark.sql.window import Window
from pyspark.errors import IllegalArgumentException, PythonException
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)

if have_pandas:
    import pandas as pd
    from pandas.testing import assert_frame_equal

if have_pyarrow:
    import pyarrow as pa  # noqa: F401


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class CogroupedApplyInPandasTestsMixin:
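    """Tests for cogrouped ``applyInPandas`` over two grouped DataFrames: merge
    semantics, output schema and column-name handling, grouping keys, and error
    reporting.
    """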
    @property
    def data1(self):
        return (
            self.spark.range(10)
            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
            .withColumn("k", explode(col("ks")))
            .withColumn("v", col("k") * 10)
            .drop("ks")
        )

    @property
    def data2(self):
        return (
            self.spark.range(10)
            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
            .withColumn("k", explode(col("ks")))
            .withColumn("v2", col("k") * 100)
            .drop("ks")
        )

    def test_simple(self):
        self._test_merge(self.data1, self.data2)

    def test_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_merge(left, self.data2)

    def test_right_group_empty(self):
        right = self.data2.where(col("id") % 2 == 0)
        self._test_merge(self.data1, right)

    def test_different_schemas(self):
        right = self.data2.withColumn("v3", lit("a"))
        self._test_merge(
            self.data1, right, output_schema="id long, k int, v int, v2 int, v3 string"
        )

    def test_different_keys(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            return pd.merge(lft.rename(columns={"id2": "id"}), rgt, on=["id", "k"])

        result = (
            left.withColumnRenamed("id", "id2")
            .groupby("id2")
            .cogroup(right.groupby("id"))
            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
            .sort(["id", "k"])
            .toPandas()
        )

        left = left.toPandas()
        right = right.toPandas()

        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])

        assert_frame_equal(expected, result)
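
    # Both sides of a cogroup may be keyed by expressions rather than plain columns;
    # here each side groups on the parity of "id".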
    def test_complex_group_by(self):
        left = pd.DataFrame.from_dict({"id": [1, 2, 3], "k": [5, 6, 7], "v": [9, 10, 11]})
        right = pd.DataFrame.from_dict({"id": [11, 12, 13], "k": [5, 6, 7], "v2": [90, 100, 110]})

        left_gdf = self.spark.createDataFrame(left).groupby(col("id") % 2 == 0)
        right_gdf = self.spark.createDataFrame(right).groupby(col("id") % 2 == 0)

        def merge_pandas(lft, rgt):
            return pd.merge(lft[["k", "v"]], rgt[["k", "v2"]], on=["k"])

        result = (
            left_gdf.cogroup(right_gdf)
            .applyInPandas(merge_pandas, "k long, v long, v2 long")
            .sort(["k"])
            .toPandas()
        )

        expected = pd.DataFrame.from_dict({"k": [5, 6, 7], "v": [9, 10, 11], "v2": [90, 100, 110]})

        assert_frame_equal(expected, result)
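
    # Grouping by no columns at all is allowed: every row lands in a single cogroup,
    # so the result is the plain merge of the two full frames.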
    def test_empty_group_by(self):
        self._test_merge(self.data1, self.data2, by=[])
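
    # The two sides must be keyed by the same number of columns; grouping one side by
    # (id, k) and the other by (id) alone is rejected when the plan is analyzed, which
    # is why accessing .schema below is enough to trigger the error.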
    def test_different_group_key_cardinality(self):
        with self.quiet():
            self.check_different_group_key_cardinality()

    def check_different_group_key_cardinality(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, _):
            return lft

        with self.assertRaisesRegex(
            IllegalArgumentException,
            "requirement failed: Cogroup keys must have same size: 2 != 1",
        ):
            (left.groupby("id", "k").cogroup(right.groupby("id"))).applyInPandas(
                merge_pandas, "id long, k int, v int"
            ).schema

    def test_apply_in_pandas_not_returning_pandas_dataframe(self):
        with self.quiet():
            self.check_apply_in_pandas_not_returning_pandas_dataframe()

    def check_apply_in_pandas_not_returning_pandas_dataframe(self):
        self._test_merge_error(
            fn=lambda lft, rgt: lft.size + rgt.size,
            error_class=PythonException,
            error_message_regex="Return type of the user-defined function "
            "should be pandas.DataFrame, but is int",
        )

    def test_apply_in_pandas_returning_column_names(self):
        self._test_merge(fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]))
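
    # When the returned pandas.DataFrame carries no string column names (e.g. a plain
    # RangeIndex), its columns are matched to the output schema by position rather
    # than by name, as the next two tests exercise.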
    def test_apply_in_pandas_returning_no_column_names(self):
        def merge_pandas(lft, rgt):
            res = pd.merge(lft, rgt, on=["id", "k"])
            res.columns = range(res.columns.size)
            return res

        self._test_merge(fn=merge_pandas)

    def test_apply_in_pandas_returning_column_names_sometimes(self):
        def merge_pandas(lft, rgt):
            res = pd.merge(lft, rgt, on=["id", "k"])
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                return res
            res.columns = range(res.columns.size)
            return res

        self._test_merge(fn=merge_pandas)

    def test_apply_in_pandas_returning_wrong_column_names(self):
        with self.quiet():
            self.check_apply_in_pandas_returning_wrong_column_names()

    def check_apply_in_pandas_returning_wrong_column_names(self):
        def merge_pandas(lft, rgt):
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                lft["add"] = 0
            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
                rgt["more"] = 1
            return pd.merge(lft, rgt, on=["id", "k"])

        self._test_merge_error(
            fn=merge_pandas,
            error_class=PythonException,
            error_message_regex="Column names of the returned pandas.DataFrame "
            "do not match specified schema. Unexpected: add, more.\n",
        )

    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
        with self.quiet():
            self.check_apply_in_pandas_returning_no_column_names_and_wrong_amount()

    def check_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
        def merge_pandas(lft, rgt):
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                lft[3] = 0
            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
                rgt[3] = 1
            res = pd.merge(lft, rgt, on=["id", "k"])
            res.columns = range(res.columns.size)
            return res

        self._test_merge_error(
            fn=merge_pandas,
            error_class=PythonException,
            error_message_regex="Number of columns of the returned pandas.DataFrame "
            "doesn't match specified schema. Expected: 4 Actual: 6\n",
        )

    def test_apply_in_pandas_returning_empty_dataframe(self):
        def merge_pandas(lft, rgt):
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                return pd.DataFrame()
            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
                return pd.DataFrame()
            return pd.merge(lft, rgt, on=["id", "k"])

        self._test_merge_empty(fn=merge_pandas)
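
    # Arrow conversion failures surface as PythonExceptions; the message wording
    # depends on spark.sql.execution.pandas.convertToArrowArraySafely, so the next
    # test toggles that config and checks both variants.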
    def test_apply_in_pandas_returning_incompatible_type(self):
        with self.quiet():
            self.check_apply_in_pandas_returning_incompatible_type()

    def check_apply_in_pandas_returning_incompatible_type(self):
        for safely in [True, False]:
            with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
                {"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
            ):
                # sometimes we see ValueErrors
                with self.subTest(convert="string to double"):
                    expected = (
                        r"ValueError: Exception thrown when converting pandas.Series \(object\) "
                        r"with name 'k' to Arrow Array \(double\)."
                    )
                    if safely:
                        expected = expected + (
                            " It can be caused by overflows or other "
                            "unsafe conversions warned by Arrow. Arrow safe type check "
                            "can be disabled by using SQL config "
                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                        )
                    self._test_merge_error(
                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}),
                        output_schema="id long, k double",
                        error_class=PythonException,
                        error_message_regex=expected,
                    )

                # sometimes we see TypeErrors
                with self.subTest(convert="double to string"):
                    expected = (
                        r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
                        r"with name 'k' to Arrow Array \(string\).\n"
                    )
                    self._test_merge_error(
                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": [2.0]}),
                        output_schema="id long, k string",
                        error_class=PythonException,
                        error_message_regex=expected,
                    )
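
    # With a single all-rows group, each side's column sums below cover ids 0-9:
    # v1 sums to 45, v2 (= v1 + 1) to 55, and v3 (= v1 + 2) to 65, so each side's
    # total is 45 + 55 + 65 = 165.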
    def test_mixed_scalar_udfs_followed_by_cogroupby_apply(self):
        df = self.spark.range(0, 10).toDF("v1")
        df = df.withColumn("v2", udf(lambda x: x + 1, "int")(df["v1"])).withColumn(
            "v3", pandas_udf(lambda x: x + 2, "int")(df["v1"])
        )

        result = (
            df.groupby()
            .cogroup(df.groupby())
            .applyInPandas(
                lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]), "sum1 int, sum2 int"
            )
            .collect()
        )

        self.assertEqual(result[0]["sum1"], 165)
        self.assertEqual(result[0]["sum2"], 165)
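
    # The applied function may optionally take the grouping key as a leading argument;
    # in the _test_with_key helper below, key[0] is the value of the single grouping
    # column "id" for the current group.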
    def test_with_key_left(self):
        self._test_with_key(self.data1, self.data1, isLeft=True)

    def test_with_key_right(self):
        self._test_with_key(self.data1, self.data1, isLeft=False)

    def test_with_key_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(left, self.data1, isLeft=True)

    def test_with_key_right_group_empty(self):
        right = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(self.data1, right, isLeft=False)

    def test_with_key_complex(self):
        def left_assign_key(key, lft, _):
            return lft.assign(key=key[0])

        result = (
            self.data1.groupby(col("id") % 2 == 0)
            .cogroup(self.data2.groupby(col("id") % 2 == 0))
            .applyInPandas(left_assign_key, "id long, k int, v int, key boolean")
            .sort(["id", "k"])
            .toPandas()
        )

        expected = self.data1.toPandas()
        expected = expected.assign(key=expected.id % 2 == 0)

        assert_frame_equal(expected, result)

    def test_wrong_return_type(self):
        with self.quiet():
            self.check_wrong_return_type()

    def check_wrong_return_type(self):
        # Test that we get a sensible exception when invalid values are passed to apply
        self._test_merge_error(
            fn=lambda l, r: l,
            output_schema=(
                StructType().add("id", LongType()).add("v", ArrayType(YearMonthIntervalType()))
            ),
            error_class=NotImplementedError,
            error_message_regex="Invalid return type.*ArrayType.*YearMonthIntervalType",
        )

    def test_wrong_args(self):
        with self.quiet():
            self.check_wrong_args()

    def check_wrong_args(self):
        self.__test_merge_error(
            fn=lambda: 1,
            output_schema=StructType([StructField("d", DoubleType())]),
            error_class=ValueError,
            error_message_regex="Invalid function",
        )

    def test_case_insensitive_grouping_column(self):
        # SPARK-31915: case-insensitive grouping column should work.
        df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
        row = (
            df1.groupby("ColUmn")
            .cogroup(df1.groupby("COLUMN"))
            .applyInPandas(lambda r, l: r + l, "column long, value long")
            .first()
        )
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

        df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
        row = (
            df1.groupby("ColUmn")
            .cogroup(df2.groupby("COLUMN"))
            .applyInPandas(lambda r, l: r + l, "column long, value long")
            .first()
        )
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

    def test_self_join(self):
        # SPARK-34319: self-join with FlatMapCoGroupsInPandas
        df = self.spark.createDataFrame([(1, 1)], ("column", "value"))

        row = (
            df.groupby("ColUmn")
            .cogroup(df.groupby("COLUMN"))
            .applyInPandas(lambda r, l: r + l, "column long, value long")
        )

        row = row.join(row).first()
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

    def test_with_window_function(self):
        # SPARK-42168: a window function with the same partition keys but differing key order
        ids = 2
        days = 100
        vals = 10000
        parts = 10

        id_df = self.spark.range(ids)
        day_df = self.spark.range(days).withColumnRenamed("id", "day")
        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
        df = id_df.join(day_df).join(vals_df)

        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
        # SPARK-42132: this bug requires us to alias all columns from df here
        right_df = (
            df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
            .repartition(parts)
            .cache()
        )

        # note the column order is different from the groupBy("id", "day") column order below
        window = Window.partitionBy("day", "id")

        left_grouped_df = left_df.groupBy("id", "day")
        right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
            "id", "day"
        )

        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
            return pd.DataFrame(
                [
                    {
                        "id": left["id"][0]
                        if not left.empty
                        else (right["id"][0] if not right.empty else None),
                        "day": left["day"][0]
                        if not left.empty
                        else (right["day"][0] if not right.empty else None),
                        "lefts": len(left.index),
                        "rights": len(right.index),
                    }
                ]
            )

        df = left_grouped_df.cogroup(right_grouped_df).applyInPandas(
            cogroup, schema="id long, day long, lefts integer, rights integer"
        )

        actual = df.orderBy("id", "day").take(days)
        self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])
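
    # In the test below, id=1 contributes two rows on each side, while id=2 has two
    # left rows but only one right row; that is exactly what the rendered table asserts.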
    def test_with_local_data(self):
        df1 = self.spark.createDataFrame(
            [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2")
        )
        df2 = self.spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3"))

        def summarize(left, right):
            return pd.DataFrame(
                {
                    "left_rows": [len(left)],
                    "left_columns": [len(left.columns)],
                    "right_rows": [len(right)],
                    "right_columns": [len(right.columns)],
                }
            )

        df = (
            df1.groupby("id")
            .cogroup(df2.groupby("id"))
            .applyInPandas(
                summarize,
                schema="left_rows long, left_columns long, right_rows long, right_columns long",
            )
        )

        self.assertEqual(
            df._show_string(),
            "+---------+------------+----------+-------------+\n"
            "|left_rows|left_columns|right_rows|right_columns|\n"
            "+---------+------------+----------+-------------+\n"
            "|        2|           3|         2|            2|\n"
            "|        2|           3|         1|            2|\n"
            "+---------+------------+----------+-------------+\n",
        )

    @staticmethod
    def _test_with_key(left, right, isLeft):
        def right_assign_key(key, lft, rgt):
            return lft.assign(key=key[0]) if isLeft else rgt.assign(key=key[0])

        result = (
            left.groupby("id")
            .cogroup(right.groupby("id"))
            .applyInPandas(right_assign_key, "id long, k int, v int, key long")
            .toPandas()
        )

        expected = left.toPandas() if isLeft else right.toPandas()
        expected = expected.assign(key=expected.id)

        assert_frame_equal(expected, result)

    def _test_merge_empty(self, fn):
        left = self.data1.toPandas()
        right = self.data2.toPandas()

        expected = pd.merge(
            left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", "k"]
        ).sort_values(by=["id", "k"])

        self._test_merge(self.data1, self.data2, fn=fn, expected=expected)
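
    # Each merge scenario is run twice: once with the two-argument fn as given, and
    # once wrapped to take the grouping key as a leading argument, since applyInPandas
    # accepts both function shapes.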
    def _test_merge(
        self,
        left=None,
        right=None,
        by=["id"],
        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
        output_schema="id long, k int, v int, v2 int",
        expected=None,
    ):
        def fn_with_key(_, lft, rgt):
            return fn(lft, rgt)

        # Test fn with and without key argument
        with self.subTest("without key"):
            self.__test_merge(left, right, by, fn, output_schema, expected)
        with self.subTest("with key"):
            self.__test_merge(left, right, by, fn_with_key, output_schema, expected)

    def __test_merge(
        self,
        left=None,
        right=None,
        by=["id"],
        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
        output_schema="id long, k int, v int, v2 int",
        expected=None,
    ):
        # Test fn as is, cf. _test_merge
        left = self.data1 if left is None else left
        right = self.data2 if right is None else right

        result = (
            left.groupby(*by)
            .cogroup(right.groupby(*by))
            .applyInPandas(fn, output_schema)
            .sort(["id", "k"])
            .toPandas()
        )

        left = left.toPandas()
        right = right.toPandas()

        expected = (
            pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])
            if expected is None
            else expected
        )

        assert_frame_equal(expected, result)

    def _test_merge_error(
        self,
        error_class,
        error_message_regex,
        left=None,
        right=None,
        by=["id"],
        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
        output_schema="id long, k int, v int, v2 int",
    ):
        def fn_with_key(_, lft, rgt):
            return fn(lft, rgt)

        # Test fn with and without key argument
        with self.subTest("without key"):
            self.__test_merge_error(
                left=left,
                right=right,
                by=by,
                fn=fn,
                output_schema=output_schema,
                error_class=error_class,
                error_message_regex=error_message_regex,
            )
        with self.subTest("with key"):
            self.__test_merge_error(
                left=left,
                right=right,
                by=by,
                fn=fn_with_key,
                output_schema=output_schema,
                error_class=error_class,
                error_message_regex=error_message_regex,
            )

    def __test_merge_error(
        self,
        error_class,
        error_message_regex,
        left=None,
        right=None,
        by=["id"],
        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
        output_schema="id long, k int, v int, v2 int",
    ):
        # Test fn as is, cf. _test_merge_error
        with self.assertRaisesRegex(error_class, error_message_regex):
            self.__test_merge(left, right, by, fn, output_schema)


class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)