#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
import unittest
import logging

from pyspark.errors import PythonException
from pyspark.sql import Row
from pyspark.sql import functions as sf
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_pyarrow,
    pyarrow_requirement_message,
)
from pyspark.testing.utils import assertDataFrameEqual
from pyspark.util import is_remote_only

if have_pyarrow:
    import pyarrow as pa
    import pyarrow.compute as pc


@unittest.skipIf(
    not have_pyarrow,
    pyarrow_requirement_message,
)
class CogroupedMapInArrowTestsMixin:
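    """
    Tests for cogrouped-map applyInArrow: two grouped DataFrames are cogrouped and the
    user-defined function receives each pair of cogroups as pyarrow.Tables.
    """
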
    @property
    def left(self):
        return self.spark.range(0, 10, 2, 3).withColumn("v", sf.col("id") * 10)

    @property
    def right(self):
        return self.spark.range(0, 10, 3, 3).withColumn("v", sf.col("id") * 10)

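    # Both DataFrames are grouped by id // 4, so cogrouping pairs up rows whose ids fall
    # into the same bucket of four.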
    @property
    def cogrouped(self):
        grouped_left_df = self.left.groupBy((sf.col("id") / 4).cast("int"))
        grouped_right_df = self.right.groupBy((sf.col("id") / 4).cast("int"))
        return grouped_left_df.cogroup(grouped_right_df)

    @staticmethod
    def apply_in_arrow_func(left, right):
        assert isinstance(left, pa.Table)
        assert isinstance(right, pa.Table)
        assert left.schema.names == ["id", "v"]
        assert right.schema.names == ["id", "v"]

        left_ids = left.to_pydict()["id"]
        right_ids = right.to_pydict()["id"]
        result = {
            "metric": ["min", "max", "len", "sum"],
            "left": [min(left_ids), max(left_ids), len(left_ids), sum(left_ids)],
            "right": [min(right_ids), max(right_ids), len(right_ids), sum(right_ids)],
        }
        return pa.Table.from_pydict(result)

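    # Same aggregation as apply_in_arrow_func, but additionally verifies that the grouping
    # key handed to the function matches the key derived from the given key column of both
    # input tables.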
    @staticmethod
    def apply_in_arrow_with_key_func(key_column):
        def func(key, left, right):
            assert isinstance(key, tuple)
            assert all(isinstance(scalar, pa.Scalar) for scalar in key)
            if key_column:
                assert all(
                    (pc.divide(k, pa.scalar(4)).cast(pa.int32()),) == key
                    for table in [left, right]
                    for k in table.column(key_column)
                )
            return CogroupedMapInArrowTestsMixin.apply_in_arrow_func(left, right)

        return func

    @staticmethod
    def apply_in_pandas_with_key_func(key_column):
        def func(key, left, right):
            return CogroupedMapInArrowTestsMixin.apply_in_arrow_with_key_func(key_column)(
                tuple(pa.scalar(k) for k in key),
                pa.Table.from_pandas(left),
                pa.Table.from_pandas(right),
            ).to_pandas()

        return func

    def do_test_apply_in_arrow(self, cogrouped_df, key_column="id"):
        schema = "metric string, left long, right long"

        # compare with result of applyInPandas
        expected = cogrouped_df.applyInPandas(
            CogroupedMapInArrowTestsMixin.apply_in_pandas_with_key_func(key_column), schema
        )

        # apply in arrow without key
        actual = cogrouped_df.applyInArrow(
            CogroupedMapInArrowTestsMixin.apply_in_arrow_func, schema
        ).collect()
        self.assertEqual(actual, expected.collect())

        # apply in arrow with key
        actual2 = cogrouped_df.applyInArrow(
            CogroupedMapInArrowTestsMixin.apply_in_arrow_with_key_func(key_column), schema
        ).collect()
        self.assertEqual(actual2, expected.collect())

    def test_apply_in_arrow(self):
        self.do_test_apply_in_arrow(self.cogrouped)

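    # groupBy() without columns yields a single cogroup, and there is no key column to verify.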
    def test_apply_in_arrow_empty_groupby(self):
        grouped_left_df = self.left.groupBy()
        grouped_right_df = self.right.groupBy()
        cogrouped_df = grouped_left_df.cogroup(grouped_right_df)
        self.do_test_apply_in_arrow(cogrouped_df, key_column=None)

    def test_apply_in_arrow_not_returning_arrow_table(self):
        def func(key, left, right):
            return key

        with self.quiet():
            with self.assertRaisesRegex(
                PythonException,
                "Return type of the user-defined function should be pyarrow.Table, but is tuple",
            ):
                self.cogrouped.applyInArrow(func, schema="id long").collect()

    def test_apply_in_arrow_returning_wrong_types(self):
        for schema, expected in [
            ("id integer, v long", "column 'id' \\(expected int32, actual int64\\)"),
            (
                "id integer, v integer",
                "column 'id' \\(expected int32, actual int64\\), "
                "column 'v' \\(expected int32, actual int64\\)",
            ),
            ("id long, v integer", "column 'v' \\(expected int32, actual int64\\)"),
            ("id long, v string", "column 'v' \\(expected string, actual int64\\)"),
        ]:
            with self.subTest(schema=schema):
                with self.quiet():
                    with self.assertRaisesRegex(
                        PythonException,
                        f"Columns do not match in their data type: {expected}",
                    ):
                        self.cogrouped.applyInArrow(
                            lambda left, right: left, schema=schema
                        ).collect()

    def test_apply_in_arrow_returning_wrong_types_positional_assignment(self):
        for schema, expected in [
            ("a integer, b long", "column 'a' \\(expected int32, actual int64\\)"),
            (
                "a integer, b integer",
                "column 'a' \\(expected int32, actual int64\\), "
                "column 'b' \\(expected int32, actual int64\\)",
            ),
            ("a long, b int", "column 'b' \\(expected int32, actual int64\\)"),
            ("a long, b string", "column 'b' \\(expected string, actual int64\\)"),
        ]:
            with self.subTest(schema=schema):
                with self.sql_conf(
                    {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
                ):
                    with self.quiet():
                        with self.assertRaisesRegex(
                            PythonException,
                            f"Columns do not match in their data type: {expected}",
                        ):
                            self.cogrouped.applyInArrow(
                                lambda left, right: left, schema=schema
                            ).collect()

    def test_apply_in_arrow_returning_wrong_column_names(self):
        def stats(key, left, right):
            # returning three columns
            return pa.Table.from_pydict(
                {
                    "id": [key[0].as_py()],
                    "v": [pc.mean(left.column("v")).as_py()],
                    "v2": [pc.stddev(right.column("v")).as_py()],
                }
            )

        with self.quiet():
            with self.assertRaisesRegex(
                PythonException,
                "Column names of the returned pyarrow.Table do not match specified schema. "
                "Missing: m. Unexpected: v, v2.",
            ):
                # stats returns three columns while here we set schema with two columns
                self.cogrouped.applyInArrow(stats, schema="id long, m double").collect()

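    # Groups for which the function returns an empty table contribute no rows to the result.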
    def test_apply_in_arrow_returning_empty_dataframe(self):
        def odd_means(key, left, right):
            if key[0].as_py() == 0:
                return pa.table([])
            else:
                return pa.Table.from_pydict(
                    {
                        "id": [key[0].as_py()],
                        "m": [pc.mean(left.column("v")).as_py()],
                        "n": [pc.mean(right.column("v")).as_py()],
                    }
                )

        schema = "id long, m double, n double"
        actual = self.cogrouped.applyInArrow(odd_means, schema=schema).sort("id").collect()
        expected = [Row(id=1, m=50.0, n=60.0), Row(id=2, m=80.0, n=90.0)]
        self.assertEqual(expected, actual)

    def test_apply_in_arrow_returning_empty_dataframe_and_wrong_column_names(self):
        def odd_means(key, left, _):
            if key[0].as_py() % 2 == 0:
                return pa.table([[]], names=["id"])
            else:
                return pa.Table.from_pydict(
                    {"id": [key[0].as_py()], "m": [pc.mean(left.column("v")).as_py()]}
                )

        with self.quiet():
            with self.assertRaisesRegex(
                PythonException,
                "Column names of the returned pyarrow.Table do not match specified schema. "
                "Missing: m.",
            ):
                # odd_means returns only one column for even keys, while the schema has two
                self.cogrouped.applyInArrow(odd_means, schema="id long, m double").collect()

    def test_apply_in_arrow_column_order(self):
        df = self.left
        expected = df.select(df.id, (df.v * 3).alias("u"), df.v).collect()

        # Function returns a table with required column names but different order
        def change_col_order(left, _):
            return left.append_column("u", pc.multiply(left.column("v"), 3))

        # The result should assign columns by name from the table
        result = (
            self.cogrouped.applyInArrow(change_col_order, "id long, u long, v long")
            .sort("id", "v")
            .select("id", "u", "v")
            .collect()
        )
        self.assertEqual(expected, result)

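    # With assignColumnsByName disabled, result columns are matched to the schema by position
    # rather than by name.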
    def test_positional_assignment_conf(self):
        with self.sql_conf(
            {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
        ):

            def foo(left, right):
                return pa.Table.from_pydict({"x": ["hi"], "y": [1]})

            result = (
                self.cogrouped.applyInArrow(foo, "a string, b long").select("a", "b").collect()
            )
            for r in result:
                self.assertEqual(r.a, "hi")
                self.assertEqual(r.b, 1)

    def test_with_local_data(self):
        df1 = self.spark.createDataFrame(
            [(1, 1.0, "a"), (2, 2.0, "b"), (1, 3.0, "c"), (2, 4.0, "d")], ("id", "v1", "v2")
        )
        df2 = self.spark.createDataFrame([(1, "x"), (2, "y"), (1, "z")], ("id", "v3"))

        def summarize(left, right):
            return pa.Table.from_pydict(
                {
                    "left_rows": [left.num_rows],
                    "left_columns": [left.num_columns],
                    "right_rows": [right.num_rows],
                    "right_columns": [right.num_columns],
                }
            )

        df = (
            df1.groupby("id")
            .cogroup(df2.groupby("id"))
            .applyInArrow(
                summarize,
                schema="left_rows long, left_columns long, right_rows long, right_columns long",
            )
        )

        self.assertEqual(
            df._show_string(),
            "+---------+------------+----------+-------------+\n"
            "|left_rows|left_columns|right_rows|right_columns|\n"
            "+---------+------------+----------+-------------+\n"
            "|        2|           3|         2|            2|\n"
            "|        2|           3|         1|            2|\n"
            "+---------+------------+----------+-------------+\n",
        )

    def test_self_join(self):
        df = self.spark.createDataFrame([(1, 1)], ("k", "v"))

        def arrow_func(key, left, right):
            return pa.Table.from_pydict({"x": [2], "y": [2]})

        df2 = df.groupby("k").cogroup(df.groupby("k")).applyInArrow(arrow_func, "x long, y long")

        self.assertEqual(df2.join(df2).count(), 1)

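    # Each group should reach the function as a single Arrow table even when the batch size
    # limits force the exchanged data to be split into many Arrow batches.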
    def test_arrow_batch_slicing(self):
        m, n = 100000, 10000

        df1 = self.spark.range(m).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
        cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
        df1 = df1.withColumns(cols)

        df2 = self.spark.range(n).select((sf.col("id") % 4).alias("key"), sf.col("id").alias("v"))
        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
        df2 = df2.withColumns(cols)

        def summarize(key, left, right):
            assert len(left) == m / 2 or len(left) == 0, len(left)
            assert len(right) == n / 4, len(right)
            return pa.Table.from_pydict(
                {
                    "key": [key[0].as_py()],
                    "left_rows": [left.num_rows],
                    "left_columns": [left.num_columns],
                    "right_rows": [right.num_rows],
                    "right_columns": [right.num_columns],
                }
            )

        schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"

        expected = [
            Row(key=0, left_rows=m / 2, left_columns=12, right_rows=n / 4, right_columns=22),
            Row(key=1, left_rows=m / 2, left_columns=12, right_rows=n / 4, right_columns=22),
            Row(key=2, left_rows=0, left_columns=12, right_rows=n / 4, right_columns=22),
            Row(key=3, left_rows=0, left_columns=12, right_rows=n / 4, right_columns=22),
        ]

        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
                with self.sql_conf(
                    {
                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
                    }
                ):
                    result = (
                        df1.groupby("key")
                        .cogroup(df2.groupby("key"))
                        .applyInArrow(summarize, schema=schema)
                        .sort("key")
                        .collect()
                    )

                    self.assertEqual(expected, result)

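    # A maxRecordsPerBatch of zero or a negative value means the record-count limit is
    # effectively disabled; applyInArrow should still work in that case.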
    def test_negative_and_zero_batch_size(self):
        for batch_size in [0, -1]:
            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                CogroupedMapInArrowTestsMixin.test_apply_in_arrow(self)

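    # Log records emitted by the function run in the Python worker; with worker logging
    # enabled they can be queried back through the python_worker_logs table-valued function.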
    @unittest.skipIf(is_remote_only(), "Requires JVM access")
    def test_cogroup_apply_in_arrow_with_logging(self):
        import pyarrow as pa

        def func_with_logging(left, right):
            assert isinstance(left, pa.Table)
            assert isinstance(right, pa.Table)
            logger = logging.getLogger("test_arrow_cogrouped_map")
            logger.warning(
                "arrow cogrouped map: "
                + f"{dict(v1=left['v1'].to_pylist(), v2=right['v2'].to_pylist())}"
            )
            return left.join(right, keys="id", join_type="inner")

        left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"])
        right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"])

        grouped_left = left_df.groupBy("id")
        grouped_right = right_df.groupBy("id")
        cogrouped_df = grouped_left.cogroup(grouped_right)

        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
            assertDataFrameEqual(
                cogrouped_df.applyInArrow(func_with_logging, "id long, v1 long, v2 long"),
                [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]]
                + [Row(id=2, v1=20, v2=200)],
            )

            logs = self.spark.tvf.python_worker_logs()

            assertDataFrameEqual(
                logs.select("level", "msg", "context", "logger"),
                [
                    Row(
                        level="WARNING",
                        msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
                        context={"func_name": func_with_logging.__name__},
                        logger="test_arrow_cogrouped_map",
                    )
                    for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
                ],
            )


class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize default timezone between Python and Java
        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()


if __name__ == "__main__":
    from pyspark.testing import main

    main()