#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
from typing import cast

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.sql.window import Window
from pyspark.errors import IllegalArgumentException, PythonException
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd
    from pandas.testing import assert_frame_equal

if have_pyarrow:
    import pyarrow as pa  # noqa: F401


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class CogroupedApplyInPandasTestsMixin:
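    # Fixture data: data1 and data2 share the grouping columns (id, k);
    # data1 carries v = k * 10 and data2 carries v2 = k * 100.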
    @property
    def data1(self):
        return (
            self.spark.range(10)
            .toDF("id")
            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
            .withColumn("k", explode(col("ks")))
            .withColumn("v", col("k") * 10)
            .drop("ks")
        )

    @property
    def data2(self):
        return (
            self.spark.range(10)
            .toDF("id")
            .withColumn("ks", array([lit(i) for i in range(20, 30)]))
            .withColumn("k", explode(col("ks")))
            .withColumn("v2", col("k") * 100)
            .drop("ks")
        )

    def test_simple(self):
        self._test_merge(self.data1, self.data2)

    def test_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_merge(left, self.data2)

    def test_right_group_empty(self):
        right = self.data2.where(col("id") % 2 == 0)
        self._test_merge(self.data1, right)

    def test_different_schemas(self):
        right = self.data2.withColumn("v3", lit("a"))
        self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3 string")

    def test_different_keys(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            return pd.merge(lft.rename(columns={"id2": "id"}), rgt, on=["id", "k"])

        result = (
            left.withColumnRenamed("id", "id2")
            .groupby("id2")
            .cogroup(right.groupby("id"))
            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
            .sort(["id", "k"])
            .toPandas()
        )

        left = left.toPandas()
        right = right.toPandas()
        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])

        assert_frame_equal(expected, result)

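    # Grouping by an expression instead of a plain column: both sides group on id % 2 == 0.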
    def test_complex_group_by(self):
        left = pd.DataFrame.from_dict({"id": [1, 2, 3], "k": [5, 6, 7], "v": [9, 10, 11]})
        right = pd.DataFrame.from_dict({"id": [11, 12, 13], "k": [5, 6, 7], "v2": [90, 100, 110]})

        left_gdf = self.spark.createDataFrame(left).groupby(col("id") % 2 == 0)
        right_gdf = self.spark.createDataFrame(right).groupby(col("id") % 2 == 0)

        def merge_pandas(lft, rgt):
            return pd.merge(lft[["k", "v"]], rgt[["k", "v2"]], on=["k"])

        result = (
            left_gdf.cogroup(right_gdf)
            .applyInPandas(merge_pandas, "k long, v long, v2 long")
            .sort(["k"])
            .toPandas()
        )

        expected = pd.DataFrame.from_dict({"k": [5, 6, 7], "v": [9, 10, 11], "v2": [90, 100, 110]})

        assert_frame_equal(expected, result)

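    # groupby() with no keys places each whole DataFrame into a single cogroup.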
    def test_empty_group_by(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            return pd.merge(lft, rgt, on=["id", "k"])

        result = (
            left.groupby()
            .cogroup(right.groupby())
            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
            .sort(["id", "k"])
            .toPandas()
        )

        left = left.toPandas()
        right = right.toPandas()
        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])

        assert_frame_equal(expected, result)

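    # Both sides of a cogroup must group by the same number of keys; a mismatch fails fast.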
    def test_different_group_key_cardinality(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, _):
            return lft

        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                IllegalArgumentException,
                "requirement failed: Cogroup keys must have same size: 2 != 1",
            ):
                (left.groupby("id", "k").cogroup(right.groupby("id"))).applyInPandas(
                    merge_pandas, "id long, k int, v int"
                )

    def test_apply_in_pandas_not_returning_pandas_dataframe(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            return lft.size + rgt.size

        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                PythonException,
                "Return type of the user-defined function should be pandas.DataFrame, "
                "but is <class 'numpy.int64'>",
            ):
                (
                    left.groupby("id")
                    .cogroup(right.groupby("id"))
                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
                    .collect()
                )

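    # The returned pandas.DataFrame must have exactly as many columns as the declared schema.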
    def test_apply_in_pandas_returning_wrong_number_of_columns(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                lft["add"] = 0
            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
                rgt["more"] = 1
            return pd.merge(lft, rgt, on=["id", "k"])

        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                PythonException,
                "Number of columns of the returned pandas.DataFrame "
                "doesn't match specified schema. Expected: 4 Actual: 6",
            ):
                (
                    # merge_pandas returns six columns for ids divisible by both 2 and 3
                    # (the extra "add" and "more" columns) while the schema declares four
                    left.groupby("id")
                    .cogroup(right.groupby("id"))
                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
                    .collect()
                )

    def test_apply_in_pandas_returning_empty_dataframe(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                return pd.DataFrame([])
            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
                return pd.DataFrame([])
            return pd.merge(lft, rgt, on=["id", "k"])

        result = (
            left.groupby("id")
            .cogroup(right.groupby("id"))
            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
            .sort(["id", "k"])
            .toPandas()
        )

        left = left.toPandas()
        right = right.toPandas()
        expected = pd.merge(
            left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", "k"]
        ).sort_values(by=["id", "k"])

        assert_frame_equal(expected, result)

    def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
        left = self.data1
        right = self.data2

        def merge_pandas(lft, rgt):
            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                return pd.DataFrame([], columns=["id", "k"])
            return pd.merge(lft, rgt, on=["id", "k"])

        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                PythonException,
                "Number of columns of the returned pandas.DataFrame doesn't "
                "match specified schema. Expected: 4 Actual: 2",
            ):
                (
                    # merge_pandas returns two columns for even keys while we set schema to four
                    left.groupby("id")
                    .cogroup(right.groupby("id"))
                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
                    .collect()
                )

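    # Mixes a plain Python UDF and a pandas UDF before the cogrouped apply; each sum is
    # sum(v1) + sum(v2) + sum(v3) = 45 + 55 + 65 = 165.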
    def test_mixed_scalar_udfs_followed_by_cogroupby_apply(self):
        df = self.spark.range(0, 10).toDF("v1")
        df = df.withColumn("v2", udf(lambda x: x + 1, "int")(df["v1"])).withColumn(
            "v3", pandas_udf(lambda x: x + 2, "int")(df["v1"])
        )

        result = (
            df.groupby()
            .cogroup(df.groupby())
            .applyInPandas(
                lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]), "sum1 int, sum2 int"
            )
            .collect()
        )

        self.assertEqual(result[0]["sum1"], 165)
        self.assertEqual(result[0]["sum2"], 165)

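    # applyInPandas functions may also take (key, left, right); these tests verify
    # that the grouping key tuple is passed through correctly.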
    def test_with_key_left(self):
        self._test_with_key(self.data1, self.data1, isLeft=True)

    def test_with_key_right(self):
        self._test_with_key(self.data1, self.data1, isLeft=False)

    def test_with_key_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(left, self.data1, isLeft=True)

    def test_with_key_right_group_empty(self):
        right = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(self.data1, right, isLeft=False)

    def test_with_key_complex(self):
        def left_assign_key(key, lft, _):
            return lft.assign(key=key[0])

        result = (
            self.data1.groupby(col("id") % 2 == 0)
            .cogroup(self.data2.groupby(col("id") % 2 == 0))
            .applyInPandas(left_assign_key, "id long, k int, v int, key boolean")
            .sort(["id", "k"])
            .toPandas()
        )

        expected = self.data1.toPandas()
        expected = expected.assign(key=expected.id % 2 == 0)

        assert_frame_equal(expected, result)

    def test_wrong_return_type(self):
        # Test that we get a sensible exception when invalid values are passed to apply
        left = self.data1
        right = self.data2

        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                NotImplementedError, "Invalid return type.*ArrayType.*TimestampType"
            ):
                left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
                    lambda l, r: l, "id long, v array<timestamp>"
                )

    def test_wrong_args(self):
        left = self.data1
        right = self.data2

        with self.assertRaisesRegex(ValueError, "Invalid function"):
            left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
                lambda: 1, StructType([StructField("d", DoubleType())])
            )

    def test_case_insensitive_grouping_column(self):
        # SPARK-31915: case-insensitive grouping column should work.
        df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
        row = (
            df1.groupby("ColUmn")
            .cogroup(df1.groupby("COLUMN"))
            .applyInPandas(lambda r, l: r + l, "column long, value long")
            .first()
        )
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

        df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
        row = (
            df1.groupby("ColUmn")
            .cogroup(df2.groupby("COLUMN"))
            .applyInPandas(lambda r, l: r + l, "column long, value long")
            .first()
        )
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

    def test_self_join(self):
        # SPARK-34319: self-join with FlatMapCoGroupsInPandas
        df = self.spark.createDataFrame([(1, 1)], ("column", "value"))

        row = (
            df.groupby("ColUmn")
            .cogroup(df.groupby("COLUMN"))
            .applyInPandas(lambda r, l: r + l, "column long, value long")
        )

        row = row.join(row).first()
        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

    def test_with_window_function(self):
        # SPARK-42168: a window function with the same partition keys but differing key order
        ids = 2
        days = 100
        vals = 10000
        parts = 10

        id_df = self.spark.range(ids)
        day_df = self.spark.range(days).withColumnRenamed("id", "day")
        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
        df = id_df.join(day_df).join(vals_df)

        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
        # SPARK-42132: this bug requires us to alias all columns from df here
        right_df = (
            df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
            .repartition(parts)
            .cache()
        )

        # note the column order is different from the groupBy("id", "day") column order below
        window = Window.partitionBy("day", "id")

        left_grouped_df = left_df.groupBy("id", "day")
        right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
            "id", "day"
        )

        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
            return pd.DataFrame(
                [
                    {
                        "id": left["id"][0]
                        if not left.empty
                        else (right["id"][0] if not right.empty else None),
                        "day": left["day"][0]
                        if not left.empty
                        else (right["day"][0] if not right.empty else None),
                        "lefts": len(left.index),
                        "rights": len(right.index),
                    }
                ]
            )

        df = left_grouped_df.cogroup(right_grouped_df).applyInPandas(
            cogroup, schema="id long, day long, lefts integer, rights integer"
        )

        actual = df.orderBy("id", "day").take(days)
        self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])

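    # Shared helpers: _test_with_key checks the key argument pass-through, and _test_merge
    # compares a cogrouped applyInPandas merge against pandas merging the same data locally.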
    @staticmethod
    def _test_with_key(left, right, isLeft):
        def right_assign_key(key, lft, rgt):
            return lft.assign(key=key[0]) if isLeft else rgt.assign(key=key[0])

        result = (
            left.groupby("id")
            .cogroup(right.groupby("id"))
            .applyInPandas(right_assign_key, "id long, k int, v int, key long")
            .toPandas()
        )

        expected = left.toPandas() if isLeft else right.toPandas()
        expected = expected.assign(key=expected.id)

        assert_frame_equal(expected, result)

    @staticmethod
    def _test_merge(left, right, output_schema="id long, k int, v int, v2 int"):
        def merge_pandas(lft, rgt):
            return pd.merge(lft, rgt, on=["id", "k"])

        result = (
            left.groupby("id")
            .cogroup(right.groupby("id"))
            .applyInPandas(merge_pandas, output_schema)
            .sort(["id", "k"])
            .toPandas()
        )

        left = left.toPandas()
        right = right.toPandas()
        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])

        assert_frame_equal(expected, result)


class CogroupedMapInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None

    unittest.main(testRunner=testRunner, verbosity=2)