| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import datetime |
| import unittest |
| |
| from pyspark.sql.types import ( |
| Row, |
| ArrayType, |
| StringType, |
| IntegerType, |
| StructType, |
| StructField, |
| BooleanType, |
| DateType, |
| TimestampType, |
| TimestampNTZType, |
| FloatType, |
| DayTimeIntervalType, |
| ) |
| from pyspark.testing.sqlutils import ( |
| ReusedSQLTestCase, |
| have_pyarrow, |
| have_pandas, |
| pandas_requirement_message, |
| pyarrow_requirement_message, |
| ) |
| from pyspark.testing.utils import assertDataFrameEqual |
| |
| |
| class DataFrameCollectionTestsMixin: |
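    """Tests for collecting DataFrame results on the driver: collect(), toPandas(),
    toArrow(), and toLocalIterator()."""
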
| def _to_pandas(self): |
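        # Build a DataFrame covering int, string, boolean, float, date, timestamp,
        # timestamp_ntz and day-time interval columns, then convert it with toPandas().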
| from datetime import datetime, date, timedelta |
| |
| schema = ( |
| StructType() |
| .add("a", IntegerType()) |
| .add("b", StringType()) |
| .add("c", BooleanType()) |
| .add("d", FloatType()) |
| .add("dt", DateType()) |
| .add("ts", TimestampType()) |
| .add("ts_ntz", TimestampNTZType()) |
| .add("dt_interval", DayTimeIntervalType()) |
| ) |
| data = [ |
| ( |
| 1, |
| "foo", |
| True, |
| 3.0, |
| date(1969, 1, 1), |
| datetime(1969, 1, 1, 1, 1, 1), |
| datetime(1969, 1, 1, 1, 1, 1), |
| timedelta(days=1), |
| ), |
| (2, "foo", True, 5.0, None, None, None, None), |
| ( |
| 3, |
| "bar", |
| False, |
| -1.0, |
| date(2012, 3, 3), |
| datetime(2012, 3, 3, 3, 3, 3), |
| datetime(2012, 3, 3, 3, 3, 3), |
| timedelta(hours=-1, milliseconds=421), |
| ), |
| ( |
| 4, |
| "bar", |
| False, |
| 6.0, |
| date(2100, 4, 4), |
| datetime(2100, 4, 4, 4, 4, 4), |
| datetime(2100, 4, 4, 4, 4, 4), |
| timedelta(microseconds=123), |
| ), |
| ] |
| df = self.spark.createDataFrame(data, schema) |
| return df.toPandas() |
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas(self): |
| import numpy as np |
| |
| pdf = self._to_pandas() |
| types = pdf.dtypes |
        self.assertEqual(types.iloc[0], np.int32)
        self.assertEqual(types.iloc[1], object)
        self.assertEqual(types.iloc[2], bool)
        self.assertEqual(types.iloc[3], np.float32)
        self.assertEqual(types.iloc[4], object)  # datetime.date
        self.assertEqual(types.iloc[5], "datetime64[ns]")
        self.assertEqual(types.iloc[6], "datetime64[ns]")
        self.assertEqual(types.iloc[7], "timedelta64[ns]")
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas_with_duplicated_column_names(self): |
| for arrow_enabled in [False, True]: |
| with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}): |
| self.check_to_pandas_with_duplicated_column_names() |
| |
| def check_to_pandas_with_duplicated_column_names(self): |
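        # Duplicated column names must not break the conversion; both "v" columns
        # should keep their int32 dtype.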
| import numpy as np |
| |
| sql = "select 1 v, 1 v" |
| df = self.spark.sql(sql) |
| pdf = df.toPandas() |
| types = pdf.dtypes |
| self.assertEqual(types.iloc[0], np.int32) |
| self.assertEqual(types.iloc[1], np.int32) |
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas_on_cross_join(self): |
| for arrow_enabled in [False, True]: |
| with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}): |
| self.check_to_pandas_on_cross_join() |
| |
| def check_to_pandas_on_cross_join(self): |
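        # The join below has no join condition, so it is effectively a cross join; the
        # conf is enabled explicitly so the plan is accepted regardless of the default.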
| import numpy as np |
| |
| sql = """ |
| select t1.*, t2.* from ( |
| select explode(sequence(1, 3)) v |
| ) t1 left join ( |
| select explode(sequence(1, 3)) v |
| ) t2 |
| """ |
| with self.sql_conf({"spark.sql.crossJoin.enabled": True}): |
| df = self.spark.sql(sql) |
| pdf = df.toPandas() |
| types = pdf.dtypes |
| self.assertEqual(types.iloc[0], np.int32) |
| self.assertEqual(types.iloc[1], np.int32) |
| |
| @unittest.skipIf(have_pandas, "Required Pandas was found.") |
| def test_to_pandas_required_pandas_not_found(self): |
| with self.quiet(): |
| with self.assertRaisesRegex(ImportError, "Pandas >= .* must be installed"): |
| self._to_pandas() |
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas_avoid_astype(self): |
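        # A nullable integer column cannot be held as int32 in pandas, so toPandas()
        # keeps it as float64 instead of forcing an astype() back to the integer type.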
| import numpy as np |
| |
| schema = StructType().add("a", IntegerType()).add("b", StringType()).add("c", IntegerType()) |
| data = [(1, "foo", 16777220), (None, "bar", None)] |
| df = self.spark.createDataFrame(data, schema) |
| types = df.toPandas().dtypes |
        self.assertEqual(types.iloc[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
        self.assertEqual(types.iloc[1], object)
        self.assertEqual(types.iloc[2], np.float64)
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas_from_empty_dataframe(self): |
        for arrow_enabled in [True, False]:
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
| self.check_to_pandas_from_empty_dataframe() |
| |
| def check_to_pandas_from_empty_dataframe(self): |
        # SPARK-29188: toPandas() on an empty DataFrame must keep the correct dtypes.
        # SPARK-30537: the same must hold when Arrow is enabled.
| import numpy as np |
| |
| sql = """ |
| SELECT CAST(1 AS TINYINT) AS tinyint, |
| CAST(1 AS SMALLINT) AS smallint, |
| CAST(1 AS INT) AS int, |
| CAST(1 AS BIGINT) AS bigint, |
| CAST(0 AS FLOAT) AS float, |
| CAST(0 AS DOUBLE) AS double, |
| CAST(1 AS BOOLEAN) AS boolean, |
| CAST('foo' AS STRING) AS string, |
| CAST('2019-01-01' AS TIMESTAMP) AS timestamp, |
| CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz, |
| INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval |
| """ |
| dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes |
| dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes |
| self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df)) |
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas_from_null_dataframe(self): |
        for arrow_enabled in [True, False]:
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
| self.check_to_pandas_from_null_dataframe() |
| |
| def check_to_pandas_from_null_dataframe(self): |
        # SPARK-29188: toPandas() on a DataFrame with only nulls must keep the correct dtypes.
        # SPARK-30537: the same must hold when Arrow is enabled.
| import numpy as np |
| |
| sql = """ |
| SELECT CAST(NULL AS TINYINT) AS tinyint, |
| CAST(NULL AS SMALLINT) AS smallint, |
| CAST(NULL AS INT) AS int, |
| CAST(NULL AS BIGINT) AS bigint, |
| CAST(NULL AS FLOAT) AS float, |
| CAST(NULL AS DOUBLE) AS double, |
| CAST(NULL AS BOOLEAN) AS boolean, |
| CAST(NULL AS STRING) AS string, |
| CAST(NULL AS TIMESTAMP) AS timestamp, |
| CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz, |
| INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval |
| """ |
| pdf = self.spark.sql(sql).toPandas() |
| types = pdf.dtypes |
        self.assertEqual(types.iloc[0], np.float64)
        self.assertEqual(types.iloc[1], np.float64)
        self.assertEqual(types.iloc[2], np.float64)
        self.assertEqual(types.iloc[3], np.float64)
        self.assertEqual(types.iloc[4], np.float32)
        self.assertEqual(types.iloc[5], np.float64)
        self.assertEqual(types.iloc[6], object)
        self.assertEqual(types.iloc[7], object)
        self.assertTrue(np.can_cast(np.datetime64, types.iloc[8]))
        self.assertTrue(np.can_cast(np.datetime64, types.iloc[9]))
        self.assertTrue(np.can_cast(np.timedelta64, types.iloc[10]))
| |
| @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore |
| def test_to_pandas_from_mixed_dataframe(self): |
        for arrow_enabled in [True, False]:
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
| self.check_to_pandas_from_mixed_dataframe() |
| |
| def check_to_pandas_from_mixed_dataframe(self): |
        # SPARK-29188: toPandas() on a DataFrame with some nulls must keep the correct dtypes.
        # SPARK-30537: the same must hold when Arrow is enabled.
| import numpy as np |
| |
| sql = """ |
| SELECT CAST(col1 AS TINYINT) AS tinyint, |
| CAST(col2 AS SMALLINT) AS smallint, |
| CAST(col3 AS INT) AS int, |
| CAST(col4 AS BIGINT) AS bigint, |
| CAST(col5 AS FLOAT) AS float, |
| CAST(col6 AS DOUBLE) AS double, |
| CAST(col7 AS BOOLEAN) AS boolean, |
| CAST(col8 AS STRING) AS string, |
| timestamp_seconds(col9) AS timestamp, |
| timestamp_seconds(col10) AS timestamp_ntz, |
| INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval |
| FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), |
| (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) |
| """ |
| pdf_with_some_nulls = self.spark.sql(sql).toPandas() |
| pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas() |
| self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes)) |
| |
| @unittest.skipIf( |
| not have_pandas or not have_pyarrow, |
| pandas_requirement_message or pyarrow_requirement_message, |
| ) |
| def test_to_pandas_for_array_of_struct(self): |
| for is_arrow_enabled in [True, False]: |
| with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": is_arrow_enabled}): |
| self.check_to_pandas_for_array_of_struct(is_arrow_enabled) |
| |
| def check_to_pandas_for_array_of_struct(self, is_arrow_enabled): |
| # SPARK-38098: Support Array of Struct for Pandas UDFs and toPandas |
| import numpy as np |
| import pandas as pd |
| |
| df = self.spark.createDataFrame( |
| [[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 6.0)]]], |
| "array_struct_col Array<struct<col1:string, col2:long, col3:double>>", |
| ) |
| |
| pdf = df.toPandas() |
| self.assertEqual(type(pdf), pd.DataFrame) |
| self.assertEqual(type(pdf["array_struct_col"]), pd.Series) |
| if is_arrow_enabled: |
| self.assertEqual(type(pdf["array_struct_col"][0]), np.ndarray) |
| else: |
| self.assertEqual(type(pdf["array_struct_col"][0]), list) |
| |
| @unittest.skipIf( |
| not have_pandas or not have_pyarrow, |
| pandas_requirement_message or pyarrow_requirement_message, |
| ) |
| def test_to_pandas_for_empty_df_with_nested_array_columns(self): |
| for arrow_enabled in [False, True]: |
| with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}): |
| self.check_to_pandas_for_empty_df_with_nested_array_columns() |
| |
| def check_to_pandas_for_empty_df_with_nested_array_columns(self): |
| # SPARK-51112: Segfault must not occur when converting empty DataFrame with nested array |
| # columns to pandas DataFrame. |
| import pandas as pd |
| |
| df = self.spark.createDataFrame( |
| data=[], |
| schema=StructType( |
| [ |
| StructField( |
| name="b_int", |
| dataType=IntegerType(), |
| nullable=False, |
| ), |
| StructField( |
| name="b", |
| dataType=ArrayType(ArrayType(StringType(), True), True), |
| nullable=True, |
| ), |
| ] |
| ), |
| ) |
| expected_pdf = pd.DataFrame(columns=["b_int", "b"]) |
| assertDataFrameEqual(df.toPandas(), expected_pdf) |
| |
| def test_to_local_iterator(self): |
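        # toLocalIterator() streams results one partition at a time; it must yield
        # exactly the same rows as collect().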
| df = self.spark.range(8, numPartitions=4) |
| expected = df.collect() |
| it = df.toLocalIterator() |
| self.assertEqual(expected, list(it)) |
| |
| # Test DataFrame with empty partition |
| df = self.spark.range(3, numPartitions=4) |
| it = df.toLocalIterator() |
| expected = df.collect() |
| self.assertEqual(expected, list(it)) |
| |
| def test_to_local_iterator_prefetch(self): |
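        # With prefetchPartitions=True, Spark may fetch the next partition before it
        # is consumed; the iterated rows must still match collect().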
| df = self.spark.range(8, numPartitions=4) |
| expected = df.collect() |
| it = df.toLocalIterator(prefetchPartitions=True) |
| self.assertEqual(expected, list(it)) |
| |
| def test_to_local_iterator_not_fully_consumed(self): |
| with self.quiet(): |
| self.check_to_local_iterator_not_fully_consumed() |
| |
| def check_to_local_iterator_not_fully_consumed(self): |
| # SPARK-23961: toLocalIterator throws exception when not fully consumed |
| # Create a DataFrame large enough so that write to socket will eventually block |
| df = self.spark.range(1 << 20, numPartitions=2) |
| it = df.toLocalIterator() |
| self.assertEqual(df.take(1)[0], next(it)) |
| it = None # remove iterator from scope, socket is closed when cleaned up |
| # Make sure normal df operations still work |
| result = [] |
| for i, row in enumerate(df.toLocalIterator()): |
| result.append(row) |
| if i == 7: |
| break |
| self.assertEqual(df.take(8), result) |
| |
    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message,
    )
    def test_collect_time(self):
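        # TIME column values should round-trip as datetime.time through collect(),
        # toPandas(), and toArrow().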
| import pandas as pd |
| |
| query = """ |
| SELECT * FROM VALUES |
| (TIME '12:34:56', 'a'), (TIME '22:56:01', 'b'), (NULL, 'c') |
| AS tab(t, i) |
| """ |
| |
| df = self.spark.sql(query) |
| |
| rows = df.collect() |
| self.assertEqual( |
| rows, |
| [ |
| Row(t=datetime.time(12, 34, 56), i="a"), |
| Row(t=datetime.time(22, 56, 1), i="b"), |
| Row(t=None, i="c"), |
| ], |
| ) |
| |
| pdf = df.toPandas() |
| self.assertTrue( |
| pdf.equals( |
| pd.DataFrame( |
| { |
| "t": [datetime.time(12, 34, 56), datetime.time(22, 56, 1), None], |
| "i": ["a", "b", "c"], |
| } |
| ) |
| ) |
| ) |
| |
| tbl = df.toArrow() |
| self.assertEqual( |
| [t.as_py() for t in tbl.column("t")], |
| [datetime.time(12, 34, 56), datetime.time(22, 56, 1), None], |
| ) |
| self.assertEqual( |
| [i.as_py() for i in tbl.column("i")], |
| ["a", "b", "c"], |
| ) |
| |
| |
| class DataFrameCollectionTests( |
| DataFrameCollectionTestsMixin, |
| ReusedSQLTestCase, |
| ): |
| pass |
| |
| |
| if __name__ == "__main__": |
| from pyspark.sql.tests.test_collection import * # noqa: F401 |
| |
| try: |
| import xmlrunner # type: ignore |
| |
| testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) |
| except ImportError: |
| testRunner = None |
| unittest.main(testRunner=testRunner, verbosity=2) |