| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import ctypes |
| import datetime |
| import os |
| import re |
| import threading |
| import time |
| from typing import Any |
| |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| import pytest |
| from datafusion import ( |
| DataFrame, |
| ParquetColumnOptions, |
| ParquetWriterOptions, |
| SessionContext, |
| WindowFrame, |
| column, |
| literal, |
| ) |
| from datafusion import ( |
| functions as f, |
| ) |
| from datafusion.dataframe_formatter import ( |
| DataFrameHtmlFormatter, |
| configure_formatter, |
| get_formatter, |
| reset_formatter, |
| ) |
| from datafusion.expr import Window |
| from pyarrow.csv import write_csv |
| |
| MB = 1024 * 1024 |
| |
| |
| @pytest.fixture |
| def ctx(): |
| return SessionContext() |
| |
| |
| @pytest.fixture |
| def df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])], |
| names=["a", "b", "c"], |
| ) |
| |
| return ctx.from_arrow(batch) |
| |
| |
| @pytest.fixture |
| def large_df(): |
| ctx = SessionContext() |
| |
| rows = 100000 |
| data = { |
| "a": list(range(rows)), |
| "b": [f"s-{i}" for i in range(rows)], |
| "c": [float(i + 0.1) for i in range(rows)], |
| } |
| batch = pa.record_batch(data) |
| |
| return ctx.from_arrow(batch) |
| |
| |
| @pytest.fixture |
| def struct_df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([{"c": 1}, {"c": 2}, {"c": 3}]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
| @pytest.fixture |
| def nested_df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
    # Intentionally give the list elements of "a" different lengths
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([[1], [2, 3], [4, 5, 6], None]), pa.array([7, 8, 9, 10])], |
| names=["a", "b"], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
| @pytest.fixture |
| def aggregate_df(): |
| ctx = SessionContext() |
| ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv") |
| return ctx.sql("select c1, sum(c2) from test group by c1") |
| |
| |
| @pytest.fixture |
| def partitioned_df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [ |
| pa.array([0, 1, 2, 3, 4, 5, 6]), |
| pa.array([7, None, 7, 8, 9, None, 9]), |
| pa.array(["A", "A", "A", "A", "B", "B", "B"]), |
| ], |
| names=["a", "b", "c"], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
| @pytest.fixture |
| def clean_formatter_state(): |
| """Reset the HTML formatter after each test.""" |
| reset_formatter() |
| |
| |
| @pytest.fixture |
| def null_df(): |
| """Create a DataFrame with null values of different types.""" |
| ctx = SessionContext() |
| |
| # Create a RecordBatch with nulls across different types |
| batch = pa.RecordBatch.from_arrays( |
| [ |
| pa.array([1, None, 3, None], type=pa.int64()), |
| pa.array([4.5, 6.7, None, None], type=pa.float64()), |
| pa.array(["a", None, "c", None], type=pa.string()), |
| pa.array([True, None, False, None], type=pa.bool_()), |
| pa.array( |
| [10957, None, 18993, None], type=pa.date32() |
| ), # 2000-01-01, null, 2022-01-01, null |
| pa.array( |
| [946684800000, None, 1640995200000, None], type=pa.date64() |
| ), # 2000-01-01, null, 2022-01-01, null |
| ], |
| names=[ |
| "int_col", |
| "float_col", |
| "str_col", |
| "bool_col", |
| "date32_col", |
| "date64_col", |
| ], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
# Custom style provider for testing the HTML formatter
| class CustomStyleProvider: |
| def get_cell_style(self) -> str: |
| return ( |
| "background-color: #f5f5f5; color: #333; padding: 8px; border: " |
| "1px solid #ddd;" |
| ) |
| |
| def get_header_style(self) -> str: |
| return ( |
| "background-color: #4285f4; color: white; font-weight: bold; " |
| "padding: 10px; border: 1px solid #3367d6;" |
| ) |
| |
| |
| def count_table_rows(html_content: str) -> int: |
| """Count the number of table rows in HTML content. |
| Args: |
| html_content: HTML string to analyze |
| Returns: |
| Number of table rows found (number of <tr> tags) |
| """ |
| return len(re.findall(r"<tr", html_content)) |
| |
| |
| def test_select(df): |
| df_1 = df.select( |
| column("a") + column("b"), |
| column("a") - column("b"), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df_1.collect()[0] |
| |
| assert result.column(0) == pa.array([5, 7, 9]) |
| assert result.column(1) == pa.array([-3, -3, -3]) |
| |
| df_2 = df.select("b", "a") |
| |
| # execute and collect the first (and only) batch |
| result = df_2.collect()[0] |
| |
| assert result.column(0) == pa.array([4, 5, 6]) |
| assert result.column(1) == pa.array([1, 2, 3]) |
| |
| |
| def test_select_mixed_expr_string(df): |
| df = df.select(column("b"), "a") |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([4, 5, 6]) |
| assert result.column(1) == pa.array([1, 2, 3]) |
| |
| |
| def test_filter(df): |
| df1 = df.filter(column("a") > literal(2)).select( |
| column("a") + column("b"), |
| column("a") - column("b"), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df1.collect()[0] |
| |
| assert result.column(0) == pa.array([9]) |
| assert result.column(1) == pa.array([-3]) |
| |
| df.show() |
    # verify that if no filter is applied, the internal dataframe is unchanged
| df2 = df.filter() |
| assert df.df == df2.df |
| |
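    # Multiple predicates passed to filter() are combined with AND.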
| df3 = df.filter(column("a") > literal(1), column("b") != literal(6)) |
| result = df3.collect()[0] |
| |
| assert result.column(0) == pa.array([2]) |
| assert result.column(1) == pa.array([5]) |
| assert result.column(2) == pa.array([5]) |
| |
| |
| def test_sort(df): |
| df = df.sort(column("b").sort(ascending=False)) |
| |
| table = pa.Table.from_batches(df.collect()) |
| expected = {"a": [3, 2, 1], "b": [6, 5, 4], "c": [8, 5, 8]} |
| |
| assert table.to_pydict() == expected |
| |
| |
| def test_drop(df): |
| df = df.drop("c") |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert df.schema().names == ["a", "b"] |
| assert result.column(0) == pa.array([1, 2, 3]) |
| assert result.column(1) == pa.array([4, 5, 6]) |
| |
| |
| def test_limit(df): |
| df = df.limit(1) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert len(result.column(0)) == 1 |
| assert len(result.column(1)) == 1 |
| |
| |
| def test_limit_with_offset(df): |
| # only 3 rows, but limit past the end to ensure that offset is working |
| df = df.limit(5, offset=2) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert len(result.column(0)) == 1 |
| assert len(result.column(1)) == 1 |
| |
| |
| def test_head(df): |
| df = df.head(1) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([1]) |
| assert result.column(1) == pa.array([4]) |
| assert result.column(2) == pa.array([8]) |
| |
| |
| def test_tail(df): |
| df = df.tail(1) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([3]) |
| assert result.column(1) == pa.array([6]) |
| assert result.column(2) == pa.array([8]) |
| |
| |
| def test_with_column(df): |
| df = df.with_column("c", column("a") + column("b")) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| assert result.schema.field(2).name == "c" |
| |
| assert result.column(0) == pa.array([1, 2, 3]) |
| assert result.column(1) == pa.array([4, 5, 6]) |
| assert result.column(2) == pa.array([5, 7, 9]) |
| |
| |
| def test_with_columns(df): |
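    # with_columns accepts expressions positionally, in lists, and as keyword
    # arguments (the keyword becomes the new column's name).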
| df = df.with_columns( |
| (column("a") + column("b")).alias("c"), |
| (column("a") + column("b")).alias("d"), |
| [ |
| (column("a") + column("b")).alias("e"), |
| (column("a") + column("b")).alias("f"), |
| ], |
| g=(column("a") + column("b")), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| assert result.schema.field(2).name == "c" |
| assert result.schema.field(3).name == "d" |
| assert result.schema.field(4).name == "e" |
| assert result.schema.field(5).name == "f" |
| assert result.schema.field(6).name == "g" |
| |
| assert result.column(0) == pa.array([1, 2, 3]) |
| assert result.column(1) == pa.array([4, 5, 6]) |
| assert result.column(2) == pa.array([5, 7, 9]) |
| assert result.column(3) == pa.array([5, 7, 9]) |
| assert result.column(4) == pa.array([5, 7, 9]) |
| assert result.column(5) == pa.array([5, 7, 9]) |
| assert result.column(6) == pa.array([5, 7, 9]) |
| |
| |
| def test_cast(df): |
| df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())}) |
| expected = pa.schema( |
| [("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())] |
| ) |
| |
| assert df.schema() == expected |
| |
| |
| def test_with_column_renamed(df): |
| df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum") |
| |
| result = df.collect()[0] |
| |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| assert result.schema.field(2).name == "sum" |
| |
| |
| def test_unnest(nested_df): |
| nested_df = nested_df.unnest_columns("a") |
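    # Each list element becomes its own row: values in "b" repeat once per
    # element, and the None list is preserved as a null row.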
| |
| # execute and collect the first (and only) batch |
| result = nested_df.collect()[0] |
| |
| assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6, None]) |
| assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9, 10]) |
| |
| |
| def test_unnest_without_nulls(nested_df): |
| nested_df = nested_df.unnest_columns("a", preserve_nulls=False) |
| |
| # execute and collect the first (and only) batch |
| result = nested_df.collect()[0] |
| |
| assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6]) |
| assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9]) |
| |
| |
| @pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning") |
| def test_join(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], "l") |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([8, 10])], |
| names=["a", "c"], |
| ) |
| df1 = ctx.create_dataframe([[batch]], "r") |
| |
| df2 = df.join(df1, on="a", how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| df2 = df.join(df1, left_on="a", right_on="a", how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| # Verify we don't make a breaking change to pre-43.0.0 |
| # where users would pass join_keys as a positional argument |
| df2 = df.join(df1, (["a"], ["a"]), how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| |
| def test_join_invalid_params(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], "l") |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([8, 10])], |
| names=["a", "c"], |
| ) |
| df1 = ctx.create_dataframe([[batch]], "r") |
| |
| with pytest.deprecated_call(): |
| df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| with pytest.raises( |
| ValueError, match=r"`left_on` or `right_on` should not provided with `on`" |
| ): |
| df2 = df.join(df1, on="a", how="inner", right_on="test") |
| |
| with pytest.raises( |
| ValueError, match=r"`left_on` and `right_on` should both be provided." |
| ): |
| df2 = df.join(df1, left_on="a", how="inner") |
| |
| with pytest.raises( |
| ValueError, match=r"either `on` or `left_on` and `right_on` should be provided." |
| ): |
| df2 = df.join(df1, how="inner") |
| |
| |
| def test_join_on(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], "l") |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([-8, 10])], |
| names=["a", "c"], |
| ) |
| df1 = ctx.create_dataframe([[batch]], "r") |
| |
| df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| df3 = df.join_on( |
| df1, |
| column("l.a").__eq__(column("r.a")), |
| column("l.a").__lt__(column("r.c")), |
| how="inner", |
| ) |
| df3.show() |
| df3 = df3.sort(column("l.a")) |
| table = pa.Table.from_batches(df3.collect()) |
| expected = {"a": [2], "c": [10], "b": [5]} |
| assert table.to_pydict() == expected |
| |
| |
| def test_distinct(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3, 1, 2, 3]), pa.array([4, 5, 6, 4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]).distinct().sort(column("a")) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| assert df_a.collect() == df_b.collect() |
| |
| |
| data_test_window_functions = [ |
| ( |
| "row", |
| f.row_number(order_by=[column("b"), column("a").sort(ascending=False)]), |
| [4, 2, 3, 5, 7, 1, 6], |
| ), |
| ( |
| "row_w_params", |
| f.row_number( |
| order_by=[column("b"), column("a")], |
| partition_by=[column("c")], |
| ), |
| [2, 1, 3, 4, 2, 1, 3], |
| ), |
| ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]), |
| ( |
| "rank_w_params", |
| f.rank(order_by=[column("b"), column("a")], partition_by=[column("c")]), |
| [2, 1, 3, 4, 2, 1, 3], |
| ), |
| ( |
| "dense_rank", |
| f.dense_rank(order_by=[column("b")]), |
| [2, 1, 2, 3, 4, 1, 4], |
| ), |
| ( |
| "dense_rank_w_params", |
| f.dense_rank(order_by=[column("b"), column("a")], partition_by=[column("c")]), |
| [2, 1, 3, 4, 2, 1, 3], |
| ), |
| ( |
| "percent_rank", |
| f.round(f.percent_rank(order_by=[column("b")]), literal(3)), |
| [0.333, 0.0, 0.333, 0.667, 0.833, 0.0, 0.833], |
| ), |
| ( |
| "percent_rank_w_params", |
| f.round( |
| f.percent_rank( |
| order_by=[column("b"), column("a")], partition_by=[column("c")] |
| ), |
| literal(3), |
| ), |
| [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0], |
| ), |
| ( |
| "cume_dist", |
| f.round(f.cume_dist(order_by=[column("b")]), literal(3)), |
| [0.571, 0.286, 0.571, 0.714, 1.0, 0.286, 1.0], |
| ), |
| ( |
| "cume_dist_w_params", |
| f.round( |
| f.cume_dist( |
| order_by=[column("b"), column("a")], partition_by=[column("c")] |
| ), |
| literal(3), |
| ), |
| [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0], |
| ), |
| ( |
| "ntile", |
| f.ntile(2, order_by=[column("b")]), |
| [1, 1, 1, 2, 2, 1, 2], |
| ), |
| ( |
| "ntile_w_params", |
| f.ntile(2, order_by=[column("b"), column("a")], partition_by=[column("c")]), |
| [1, 1, 2, 2, 1, 1, 2], |
| ), |
| ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9, 7, None]), |
| ( |
| "lead_w_params", |
| f.lead( |
| column("b"), |
| shift_offset=2, |
| default_value=-1, |
| order_by=[column("b"), column("a")], |
| partition_by=[column("c")], |
| ), |
| [8, 7, -1, -1, -1, 9, -1], |
| ), |
| ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8, None, 9]), |
| ( |
| "lag_w_params", |
| f.lag( |
| column("b"), |
| shift_offset=2, |
| default_value=-1, |
| order_by=[column("b"), column("a")], |
| partition_by=[column("c")], |
| ), |
| [-1, -1, None, 7, -1, -1, None], |
| ), |
| ( |
| "first_value", |
| f.first_value(column("a")).over( |
| Window(partition_by=[column("c")], order_by=[column("b")]) |
| ), |
| [1, 1, 1, 1, 5, 5, 5], |
| ), |
| ( |
| "last_value", |
| f.last_value(column("a")).over( |
| Window( |
| partition_by=[column("c")], |
| order_by=[column("b")], |
| window_frame=WindowFrame("rows", None, None), |
| ) |
| ), |
| [3, 3, 3, 3, 6, 6, 6], |
| ), |
| ( |
| "3rd_value", |
| f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])), |
| [None, None, 7, 7, 7, 7, 7], |
| ), |
| ( |
| "avg", |
| f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)), |
| [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0], |
| ), |
| ] |
| |
| |
| @pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions) |
| def test_window_functions(partitioned_df, name, expr, result): |
| df = partitioned_df.select( |
| column("a"), column("b"), column("c"), f.alias(expr, name) |
| ) |
| df.sort(column("a")).show() |
| table = pa.Table.from_batches(df.collect()) |
| |
| expected = { |
| "a": [0, 1, 2, 3, 4, 5, 6], |
| "b": [7, None, 7, 8, 9, None, 9], |
| "c": ["A", "A", "A", "A", "B", "B", "B"], |
| name: result, |
| } |
| |
| assert table.sort_by("a").to_pydict() == expected |
| |
| |
| @pytest.mark.parametrize( |
| ("units", "start_bound", "end_bound"), |
| [ |
| (units, start_bound, end_bound) |
| for units in ("rows", "range") |
| for start_bound in (None, 0, 1) |
| for end_bound in (None, 0, 1) |
| ] |
| + [ |
| ("groups", 0, 0), |
| ], |
| ) |
| def test_valid_window_frame(units, start_bound, end_bound): |
| WindowFrame(units, start_bound, end_bound) |
| |
| |
| @pytest.mark.parametrize( |
| ("units", "start_bound", "end_bound"), |
| [ |
| ("invalid-units", 0, None), |
| ("invalid-units", None, 0), |
| ("invalid-units", None, None), |
| ("groups", None, 0), |
| ("groups", 0, None), |
| ("groups", None, None), |
| ], |
| ) |
| def test_invalid_window_frame(units, start_bound, end_bound): |
| with pytest.raises(RuntimeError): |
| WindowFrame(units, start_bound, end_bound) |
| |
| |
| def test_window_frame_defaults_match_postgres(partitioned_df): |
| # ref: https://github.com/apache/datafusion-python/issues/688 |
| |
| window_frame = WindowFrame("rows", None, None) |
| |
| col_a = column("a") |
| |
| # Using `f.window` with or without an unbounded window_frame produces the same |
| # results. These tests are included as a regression check but can be removed when |
| # f.window() is deprecated in favor of using the .over() approach. |
| no_frame = f.window("avg", [col_a]).alias("no_frame") |
| with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame") |
| df_1 = partitioned_df.select(col_a, no_frame, with_frame) |
| |
| expected = { |
| "a": [0, 1, 2, 3, 4, 5, 6], |
| "no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], |
| "with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], |
| } |
| |
| assert df_1.sort(col_a).to_pydict() == expected |
| |
    # When order is not set, the default frame is unbounded preceding to
    # unbounded following. When order is set, the default frame is unbounded
    # preceding to current row.
| no_order = f.avg(col_a).over(Window()).alias("over_no_order") |
| with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order") |
| df_2 = partitioned_df.select(col_a, no_order, with_order) |
| |
| expected = { |
| "a": [0, 1, 2, 3, 4, 5, 6], |
| "over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], |
| "over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0], |
| } |
| |
| assert df_2.sort(col_a).to_pydict() == expected |
| |
| |
| def test_html_formatter_cell_dimension(df, clean_formatter_state): |
| """Test configuring the HTML formatter with different options.""" |
| # Configure with custom settings |
| configure_formatter( |
| max_width=500, |
| max_height=200, |
| enable_cell_expansion=False, |
| ) |
| |
| html_output = df._repr_html_() |
| |
| # Verify our configuration was applied |
| assert "max-height: 200px" in html_output |
| assert "max-width: 500px" in html_output |
| # With cell expansion disabled, we shouldn't see expandable-container elements |
| assert "expandable-container" not in html_output |
| |
| |
| def test_html_formatter_custom_style_provider(df, clean_formatter_state): |
| """Test using custom style providers with the HTML formatter.""" |
| |
| # Configure with custom style provider |
| configure_formatter(style_provider=CustomStyleProvider()) |
| |
| html_output = df._repr_html_() |
| |
| # Verify our custom styles were applied |
| assert "background-color: #4285f4" in html_output |
| assert "color: white" in html_output |
| assert "background-color: #f5f5f5" in html_output |
| |
| |
| def test_html_formatter_type_formatters(df, clean_formatter_state): |
| """Test registering custom type formatters for specific data types.""" |
| |
| # Get current formatter and register custom formatters |
| formatter = get_formatter() |
| |
| # Format integers with color based on value |
| # Using int as the type for the formatter will work since we convert |
| # Arrow scalar values to Python native types in _get_cell_value |
| def format_int(value): |
| return f'<span style="color: {"red" if value > 2 else "blue"}">{value}</span>' |
| |
| formatter.register_formatter(int, format_int) |
| |
| html_output = df._repr_html_() |
| |
| # Our test dataframe has values 1,2,3 so we should see: |
| assert '<span style="color: blue">1</span>' in html_output |
| |
| |
| def test_html_formatter_custom_cell_builder(df, clean_formatter_state): |
| """Test using a custom cell builder function.""" |
| |
| # Create a custom cell builder with distinct styling for different value ranges |
| def custom_cell_builder(value, row, col, table_id): |
| try: |
| num_value = int(value) |
| if num_value > 5: # Values > 5 get green background with indicator |
| return ( |
| '<td style="background-color: #d9f0d3" ' |
| f'data-test="high">{value}-high</td>' |
| ) |
| if num_value < 3: # Values < 3 get blue background with indicator |
| return ( |
| '<td style="background-color: #d3e9f0" ' |
| f'data-test="low">{value}-low</td>' |
| ) |
| except (ValueError, TypeError): |
| pass |
| |
| # Default styling for other cells (3, 4, 5) |
| return f'<td style="border: 1px solid #ddd" data-test="mid">{value}-mid</td>' |
| |
| # Set our custom cell builder |
| formatter = get_formatter() |
| formatter.set_custom_cell_builder(custom_cell_builder) |
| |
| html_output = df._repr_html_() |
| |
| # Extract cells with specific styling using regex |
| low_cells = re.findall( |
| r'<td style="background-color: #d3e9f0"[^>]*>(\d+)-low</td>', html_output |
| ) |
| mid_cells = re.findall( |
| r'<td style="border: 1px solid #ddd"[^>]*>(\d+)-mid</td>', html_output |
| ) |
| high_cells = re.findall( |
| r'<td style="background-color: #d9f0d3"[^>]*>(\d+)-high</td>', html_output |
| ) |
| |
| # Sort the extracted values for consistent comparison |
| low_cells = sorted(map(int, low_cells)) |
| mid_cells = sorted(map(int, mid_cells)) |
| high_cells = sorted(map(int, high_cells)) |
| |
| # Verify specific values have the correct styling applied |
| assert low_cells == [1, 2] # Values < 3 |
| assert mid_cells == [3, 4, 5, 5] # Values 3-5 |
| assert high_cells == [6, 8, 8] # Values > 5 |
| |
| # Verify the exact content with styling appears in the output |
| assert ( |
| '<td style="background-color: #d3e9f0" data-test="low">1-low</td>' |
| in html_output |
| ) |
| assert ( |
| '<td style="background-color: #d3e9f0" data-test="low">2-low</td>' |
| in html_output |
| ) |
| assert ( |
| '<td style="border: 1px solid #ddd" data-test="mid">3-mid</td>' in html_output |
| ) |
| assert ( |
| '<td style="border: 1px solid #ddd" data-test="mid">4-mid</td>' in html_output |
| ) |
| assert ( |
| '<td style="background-color: #d9f0d3" data-test="high">6-high</td>' |
| in html_output |
| ) |
| assert ( |
| '<td style="background-color: #d9f0d3" data-test="high">8-high</td>' |
| in html_output |
| ) |
| |
| # Count occurrences to ensure all cells are properly styled |
| assert html_output.count("-low</td>") == 2 # Two low values (1, 2) |
| assert html_output.count("-mid</td>") == 4 # Four mid values (3, 4, 5, 5) |
| assert html_output.count("-high</td>") == 3 # Three high values (6, 8, 8) |
| |
| # Create a custom cell builder that changes background color based on value |
| def custom_cell_builder(value, row, col, table_id): |
| # Handle numeric values regardless of their exact type |
| try: |
| num_value = int(value) |
| if num_value > 5: # Values > 5 get green background |
| return f'<td style="background-color: #d9f0d3">{value}</td>' |
| if num_value < 3: # Values < 3 get light blue background |
| return f'<td style="background-color: #d3e9f0">{value}</td>' |
| except (ValueError, TypeError): |
| pass |
| |
| # Default styling for other cells |
| return f'<td style="border: 1px solid #ddd">{value}</td>' |
| |
| # Set our custom cell builder |
| formatter = get_formatter() |
| formatter.set_custom_cell_builder(custom_cell_builder) |
| |
| html_output = df._repr_html_() |
| |
| # Verify our custom cell styling was applied |
| assert "background-color: #d3e9f0" in html_output # For values 1,2 |
| |
| |
| def test_html_formatter_custom_header_builder(df, clean_formatter_state): |
| """Test using a custom header builder function.""" |
| |
| # Create a custom header builder with tooltips |
| def custom_header_builder(field): |
| tooltips = { |
| "a": "Primary key column", |
| "b": "Secondary values", |
| "c": "Additional data", |
| } |
| tooltip = tooltips.get(field.name, "") |
| return ( |
| f'<th style="background-color: #333; color: white" ' |
| f'title="{tooltip}">{field.name}</th>' |
| ) |
| |
| # Set our custom header builder |
| formatter = get_formatter() |
| formatter.set_custom_header_builder(custom_header_builder) |
| |
| html_output = df._repr_html_() |
| |
| # Verify our custom headers were applied |
| assert 'title="Primary key column"' in html_output |
| assert 'title="Secondary values"' in html_output |
| assert "background-color: #333; color: white" in html_output |
| |
| |
| def test_html_formatter_complex_customization(df, clean_formatter_state): |
| """Test combining multiple customization options together.""" |
| |
| # Create a dark mode style provider |
| class DarkModeStyleProvider: |
| def get_cell_style(self) -> str: |
| return ( |
| "background-color: #222; color: #eee; " |
| "padding: 8px; border: 1px solid #444;" |
| ) |
| |
| def get_header_style(self) -> str: |
| return ( |
| "background-color: #111; color: #fff; padding: 10px; " |
| "border: 1px solid #333;" |
| ) |
| |
| # Configure with dark mode style |
| configure_formatter( |
| max_cell_length=10, |
| style_provider=DarkModeStyleProvider(), |
| custom_css=""" |
| .datafusion-table { |
| font-family: monospace; |
| border-collapse: collapse; |
| } |
| .datafusion-table tr:hover td { |
| background-color: #444 !important; |
| } |
| """, |
| ) |
| |
| # Add type formatters for special formatting - now working with native int values |
| formatter = get_formatter() |
| formatter.register_formatter( |
| int, |
| lambda n: f'<span style="color: {"#5af" if n % 2 == 0 else "#f5a"}">{n}</span>', |
| ) |
| |
| html_output = df._repr_html_() |
| |
| # Verify our customizations were applied |
| assert "background-color: #222" in html_output |
| assert "background-color: #111" in html_output |
| assert ".datafusion-table" in html_output |
| assert "color: #5af" in html_output # Even numbers |
| |
| |
| def test_html_formatter_memory(df, clean_formatter_state): |
| """Test the memory and row control parameters in DataFrameHtmlFormatter.""" |
| configure_formatter(max_memory_bytes=10, min_rows_display=1) |
| html_output = df._repr_html_() |
| |
| # Count the number of table rows in the output |
| tr_count = count_table_rows(html_output) |
| # With a tiny memory limit of 10 bytes, the formatter should display |
| # the minimum number of rows (1) plus a message about truncation |
| assert tr_count == 2 # 1 for header row, 1 for data row |
| assert "data truncated" in html_output.lower() |
| |
| configure_formatter(max_memory_bytes=10 * MB, min_rows_display=1) |
| html_output = df._repr_html_() |
    # With the larger memory limit, all three data rows should be displayed
| tr_count = count_table_rows(html_output) |
| # Table should have header row (1) + 3 data rows = 4 rows |
| assert tr_count == 4 |
| # No truncation message should appear |
| assert "data truncated" not in html_output.lower() |
| |
| |
| def test_html_formatter_repr_rows(df, clean_formatter_state): |
| configure_formatter(min_rows_display=2, repr_rows=2) |
| html_output = df._repr_html_() |
| |
| tr_count = count_table_rows(html_output) |
    # Table should have header row (1) + 2 data rows = 3 rows
| assert tr_count == 3 |
| |
| configure_formatter(min_rows_display=2, repr_rows=3) |
| html_output = df._repr_html_() |
| |
| tr_count = count_table_rows(html_output) |
    # Table should have header row (1) + 3 data rows = 4 rows
| assert tr_count == 4 |
| |
| |
| def test_html_formatter_validation(): |
| # Test validation for invalid parameters |
| |
| with pytest.raises(ValueError, match="max_cell_length must be a positive integer"): |
| DataFrameHtmlFormatter(max_cell_length=0) |
| |
| with pytest.raises(ValueError, match="max_width must be a positive integer"): |
| DataFrameHtmlFormatter(max_width=0) |
| |
| with pytest.raises(ValueError, match="max_height must be a positive integer"): |
| DataFrameHtmlFormatter(max_height=0) |
| |
| with pytest.raises(ValueError, match="max_memory_bytes must be a positive integer"): |
| DataFrameHtmlFormatter(max_memory_bytes=0) |
| |
| with pytest.raises(ValueError, match="max_memory_bytes must be a positive integer"): |
| DataFrameHtmlFormatter(max_memory_bytes=-100) |
| |
| with pytest.raises(ValueError, match="min_rows_display must be a positive integer"): |
| DataFrameHtmlFormatter(min_rows_display=0) |
| |
| with pytest.raises(ValueError, match="min_rows_display must be a positive integer"): |
| DataFrameHtmlFormatter(min_rows_display=-5) |
| |
| with pytest.raises(ValueError, match="repr_rows must be a positive integer"): |
| DataFrameHtmlFormatter(repr_rows=0) |
| |
| with pytest.raises(ValueError, match="repr_rows must be a positive integer"): |
| DataFrameHtmlFormatter(repr_rows=-10) |
| |
| |
| def test_configure_formatter(df, clean_formatter_state): |
| """Test using custom style providers with the HTML formatter and configured |
| parameters.""" |
| |
| # these are non-default values |
| max_cell_length = 10 |
| max_width = 500 |
| max_height = 30 |
| max_memory_bytes = 3 * MB |
| min_rows_display = 2 |
| repr_rows = 2 |
| enable_cell_expansion = False |
| show_truncation_message = False |
| use_shared_styles = False |
| |
| reset_formatter() |
| formatter_default = get_formatter() |
| |
| assert formatter_default.max_cell_length != max_cell_length |
| assert formatter_default.max_width != max_width |
| assert formatter_default.max_height != max_height |
| assert formatter_default.max_memory_bytes != max_memory_bytes |
| assert formatter_default.min_rows_display != min_rows_display |
| assert formatter_default.repr_rows != repr_rows |
| assert formatter_default.enable_cell_expansion != enable_cell_expansion |
| assert formatter_default.show_truncation_message != show_truncation_message |
| assert formatter_default.use_shared_styles != use_shared_styles |
| |
| # Configure with custom style provider and additional parameters |
| configure_formatter( |
| max_cell_length=max_cell_length, |
| max_width=max_width, |
| max_height=max_height, |
| max_memory_bytes=max_memory_bytes, |
| min_rows_display=min_rows_display, |
| repr_rows=repr_rows, |
| enable_cell_expansion=enable_cell_expansion, |
| show_truncation_message=show_truncation_message, |
| use_shared_styles=use_shared_styles, |
| ) |
| formatter_custom = get_formatter() |
| assert formatter_custom.max_cell_length == max_cell_length |
| assert formatter_custom.max_width == max_width |
| assert formatter_custom.max_height == max_height |
| assert formatter_custom.max_memory_bytes == max_memory_bytes |
| assert formatter_custom.min_rows_display == min_rows_display |
| assert formatter_custom.repr_rows == repr_rows |
| assert formatter_custom.enable_cell_expansion == enable_cell_expansion |
| assert formatter_custom.show_truncation_message == show_truncation_message |
| assert formatter_custom.use_shared_styles == use_shared_styles |
| |
| |
| def test_configure_formatter_invalid_params(clean_formatter_state): |
| """Test that configure_formatter rejects invalid parameters.""" |
| with pytest.raises(ValueError, match="Invalid formatter parameters"): |
| configure_formatter(invalid_param=123) |
| |
| # Test with multiple parameters, one valid and one invalid |
| with pytest.raises(ValueError, match="Invalid formatter parameters"): |
| configure_formatter(max_width=500, not_a_real_param="test") |
| |
| # Test with multiple invalid parameters |
| with pytest.raises(ValueError, match="Invalid formatter parameters"): |
| configure_formatter(fake_param1="test", fake_param2=456) |
| |
| |
| def test_get_dataframe(tmp_path): |
| ctx = SessionContext() |
| |
| path = tmp_path / "test.csv" |
| table = pa.Table.from_arrays( |
| [ |
| [1, 2, 3, 4], |
| ["a", "b", "c", "d"], |
| [1.1, 2.2, 3.3, 4.4], |
| ], |
| names=["int", "str", "float"], |
| ) |
| write_csv(table, path) |
| |
| ctx.register_csv("csv", path) |
| |
| df = ctx.table("csv") |
| assert isinstance(df, DataFrame) |
| |
| |
| def test_struct_select(struct_df): |
| df = struct_df.select( |
| column("a")["c"] + column("b"), |
| column("a")["c"] - column("b"), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([5, 7, 9]) |
| assert result.column(1) == pa.array([-3, -3, -3]) |
| |
| |
| def test_explain(df): |
| df = df.select( |
| column("a") + column("b"), |
| column("a") - column("b"), |
| ) |
| df.explain() |
| |
| |
| def test_logical_plan(aggregate_df): |
| plan = aggregate_df.logical_plan() |
| |
| expected = "Projection: test.c1, sum(test.c2)" |
| |
| assert expected == plan.display() |
| |
| expected = ( |
| "Projection: test.c1, sum(test.c2)\n" |
| " Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" |
| " TableScan: test" |
| ) |
| |
| assert expected == plan.display_indent() |
| |
| |
| def test_optimized_logical_plan(aggregate_df): |
| plan = aggregate_df.optimized_logical_plan() |
| |
| expected = "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]" |
| |
| assert expected == plan.display() |
| |
| expected = ( |
| "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" |
| " TableScan: test projection=[c1, c2]" |
| ) |
| |
| assert expected == plan.display_indent() |
| |
| |
| def test_execution_plan(aggregate_df): |
| plan = aggregate_df.execution_plan() |
| |
| expected = ( |
| "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" |
| ) |
| |
| assert expected == plan.display() |
| |
    # The partition count varies by environment, so only check that it is
    # reported as an integer.
| assert isinstance(plan.partition_count, int) |
| |
| expected = ( |
| "ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n" |
| " Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n" |
| " TableScan: test projection=[c1, c2]" |
| ) |
| |
| indent = plan.display_indent() |
| |
| # indent plan will be different for everyone due to absolute path |
| # to filename, so we just check for some expected content |
| assert "AggregateExec:" in indent |
| assert "CoalesceBatchesExec:" in indent |
| assert "RepartitionExec:" in indent |
| assert "DataSourceExec:" in indent |
| assert "file_type=csv" in indent |
| |
| ctx = SessionContext() |
| rows_returned = 0 |
| for idx in range(plan.partition_count): |
| stream = ctx.execute(plan, idx) |
| try: |
| batch = stream.next() |
| assert batch is not None |
| rows_returned += len(batch.to_pyarrow()[0]) |
| except StopIteration: |
| # This is one of the partitions with no values |
| pass |
| with pytest.raises(StopIteration): |
| stream.next() |
| |
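    # aggregate_test_100.csv has five distinct c1 values (a-e), so the grouped
    # aggregate yields five rows across all partitions.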
| assert rows_returned == 5 |
| |
| |
| @pytest.mark.asyncio |
| async def test_async_iteration_of_df(aggregate_df): |
| rows_returned = 0 |
| async for batch in aggregate_df.execute_stream(): |
| assert batch is not None |
| rows_returned += len(batch.to_pyarrow()[0]) |
| |
| assert rows_returned == 5 |
| |
| |
| def test_repartition(df): |
| df.repartition(2) |
| |
| |
| def test_repartition_by_hash(df): |
| df.repartition_by_hash(column("a"), num=2) |
| |
| |
| def test_intersect(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3]), pa.array([6])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_i_b = df_a.intersect(df_b).sort(column("a")) |
| |
| assert df_c.collect() == df_a_i_b.collect() |
| |
| |
| def test_except_all(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([4, 5])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_e_b = df_a.except_all(df_b).sort(column("a")) |
| |
| assert df_c.collect() == df_a_e_b.collect() |
| |
| |
| def test_collect_partitioned(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() |
| |
| |
| def test_union(ctx): |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_u_b = df_a.union(df_b).sort(column("a")) |
| |
| assert df_c.collect() == df_a_u_b.collect() |
| |
| |
| def test_union_distinct(ctx): |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a")) |
| |
| assert df_c.collect() == df_a_u_b.collect() |
| |
| |
| def test_cache(df): |
| assert df.cache().collect() == df.collect() |
| |
| |
| def test_count(df): |
| # Get number of rows |
| assert df.count() == 3 |
| |
| |
| def test_to_pandas(df): |
| # Skip test if pandas is not installed |
| pd = pytest.importorskip("pandas") |
| |
| # Convert datafusion dataframe to pandas dataframe |
| pandas_df = df.to_pandas() |
| assert isinstance(pandas_df, pd.DataFrame) |
| assert pandas_df.shape == (3, 3) |
| assert set(pandas_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_empty_to_pandas(df): |
| # Skip test if pandas is not installed |
| pd = pytest.importorskip("pandas") |
| |
| # Convert empty datafusion dataframe to pandas dataframe |
| pandas_df = df.limit(0).to_pandas() |
| assert isinstance(pandas_df, pd.DataFrame) |
| assert pandas_df.shape == (0, 3) |
| assert set(pandas_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_to_polars(df): |
| # Skip test if polars is not installed |
| pl = pytest.importorskip("polars") |
| |
| # Convert datafusion dataframe to polars dataframe |
| polars_df = df.to_polars() |
| assert isinstance(polars_df, pl.DataFrame) |
| assert polars_df.shape == (3, 3) |
| assert set(polars_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_empty_to_polars(df): |
| # Skip test if polars is not installed |
| pl = pytest.importorskip("polars") |
| |
| # Convert empty datafusion dataframe to polars dataframe |
| polars_df = df.limit(0).to_polars() |
| assert isinstance(polars_df, pl.DataFrame) |
| assert polars_df.shape == (0, 3) |
| assert set(polars_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_to_arrow_table(df): |
| # Convert datafusion dataframe to pyarrow Table |
| pyarrow_table = df.to_arrow_table() |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (3, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| def test_execute_stream(df): |
| stream = df.execute_stream() |
| assert all(batch is not None for batch in stream) |
| assert not list(stream) # after one iteration the generator must be exhausted |
| |
| |
| @pytest.mark.asyncio |
| async def test_execute_stream_async(df): |
| stream = df.execute_stream() |
| batches = [batch async for batch in stream] |
| |
| assert all(batch is not None for batch in batches) |
| |
| # After consuming all batches, the stream should be exhausted |
| remaining_batches = [batch async for batch in stream] |
| assert not remaining_batches |
| |
| |
| @pytest.mark.parametrize("schema", [True, False]) |
| def test_execute_stream_to_arrow_table(df, schema): |
| stream = df.execute_stream() |
| |
| if schema: |
| pyarrow_table = pa.Table.from_batches( |
| (batch.to_pyarrow() for batch in stream), schema=df.schema() |
| ) |
| else: |
| pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream) |
| |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (3, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| @pytest.mark.asyncio |
| @pytest.mark.parametrize("schema", [True, False]) |
| async def test_execute_stream_to_arrow_table_async(df, schema): |
| stream = df.execute_stream() |
| |
| if schema: |
| pyarrow_table = pa.Table.from_batches( |
| [batch.to_pyarrow() async for batch in stream], schema=df.schema() |
| ) |
| else: |
| pyarrow_table = pa.Table.from_batches( |
| [batch.to_pyarrow() async for batch in stream] |
| ) |
| |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (3, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| def test_execute_stream_partitioned(df): |
| streams = df.execute_stream_partitioned() |
| assert all(batch is not None for stream in streams for batch in stream) |
| assert all( |
| not list(stream) for stream in streams |
| ) # after one iteration all generators must be exhausted |
| |
| |
| @pytest.mark.asyncio |
| async def test_execute_stream_partitioned_async(df): |
| streams = df.execute_stream_partitioned() |
| |
| for stream in streams: |
| batches = [batch async for batch in stream] |
| assert all(batch is not None for batch in batches) |
| |
| # Ensure the stream is exhausted after iteration |
| remaining_batches = [batch async for batch in stream] |
| assert not remaining_batches |
| |
| |
| def test_empty_to_arrow_table(df): |
| # Convert empty datafusion dataframe to pyarrow Table |
| pyarrow_table = df.limit(0).to_arrow_table() |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (0, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| def test_to_pylist(df): |
| # Convert datafusion dataframe to Python list |
| pylist = df.to_pylist() |
| assert isinstance(pylist, list) |
| assert pylist == [ |
| {"a": 1, "b": 4, "c": 8}, |
| {"a": 2, "b": 5, "c": 5}, |
| {"a": 3, "b": 6, "c": 8}, |
| ] |
| |
| |
| def test_to_pydict(df): |
| # Convert datafusion dataframe to Python dictionary |
| pydict = df.to_pydict() |
| assert isinstance(pydict, dict) |
| assert pydict == {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]} |
| |
| |
| def test_describe(df): |
| # Calculate statistics |
| df = df.describe() |
| |
| # Collect the result |
| result = df.to_pydict() |
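    # Sanity check on std for column c: the values [8, 5, 8] have mean 7 and
    # sample variance (1 + 4 + 1) / 2 = 3, so std is sqrt(3).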
| |
| assert result == { |
| "describe": [ |
| "count", |
| "null_count", |
| "mean", |
| "std", |
| "min", |
| "max", |
| "median", |
| ], |
| "a": [3.0, 0.0, 2.0, 1.0, 1.0, 3.0, 2.0], |
| "b": [3.0, 0.0, 5.0, 1.0, 4.0, 6.0, 5.0], |
| "c": [3.0, 0.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0], |
| } |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_write_csv(ctx, df, tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| df.write_csv(path, with_header=True) |
| |
| ctx.register_csv("csv", path) |
| result = ctx.table("csv").to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_write_json(ctx, df, tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| df.write_json(path) |
| |
| ctx.register_json("json", path) |
| result = ctx.table("json").to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_write_parquet(df, tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| df.write_parquet(str(path)) |
| result = pq.read_table(str(path)).to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize( |
| ("compression", "compression_level"), |
| [("gzip", 6), ("brotli", 7), ("zstd", 15)], |
| ) |
| def test_write_compressed_parquet(df, tmp_path, compression, compression_level): |
| path = tmp_path |
| |
| df.write_parquet( |
| str(path), compression=compression, compression_level=compression_level |
| ) |
| |
| # test that the actual compression scheme is the one written |
| for _root, _dirs, files in os.walk(path): |
| for file in files: |
| if file.endswith(".parquet"): |
| metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() |
| for row_group in metadata["row_groups"]: |
| for columns in row_group["columns"]: |
| assert columns["compression"].lower() == compression |
| |
| result = pq.read_table(str(path)).to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize( |
| ("compression", "compression_level"), |
| [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], |
| ) |
| def test_write_compressed_parquet_wrong_compression_level( |
| df, tmp_path, compression, compression_level |
| ): |
| path = tmp_path |
| |
| with pytest.raises(ValueError): |
| df.write_parquet( |
| str(path), |
| compression=compression, |
| compression_level=compression_level, |
| ) |
| |
| |
| @pytest.mark.parametrize("compression", ["wrong"]) |
| def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression): |
| path = tmp_path |
| |
| with pytest.raises(ValueError): |
| df.write_parquet(str(path), compression=compression) |
| |
| |
# not testing lzo because it is not implemented yet
| # https://github.com/apache/arrow-rs/issues/6970 |
| @pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"]) |
| def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression): |
    # Test write_parquet with the default compression level for zstd, brotli,
    # and gzip (i.e. without specifying compression_level); this should
    # complete without error
| path = tmp_path |
| |
| df.write_parquet(str(path), compression=compression) |
| |
| |
| def test_write_parquet_with_options_default_compression(df, tmp_path): |
| """Test that the default compression is ZSTD.""" |
| df.write_parquet(tmp_path) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| metadata = pq.ParquetFile(file).metadata.to_dict() |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| assert col["compression"].lower() == "zstd" |
| |
| |
| @pytest.mark.parametrize( |
| "compression", |
| ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"], |
| ) |
| def test_write_parquet_with_options_compression(df, tmp_path, compression): |
| path = tmp_path |
| df.write_parquet_with_options( |
| str(path), ParquetWriterOptions(compression=compression) |
| ) |
| |
| # test that the actual compression scheme is the one written |
| for _root, _dirs, files in os.walk(path): |
| for file in files: |
| if file.endswith(".parquet"): |
| metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| assert col["compression"].lower() == re.sub( |
| r"\(\d+\)", "", compression |
| ) |
| |
| result = pq.read_table(str(path)).to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize( |
| "compression", |
| ["gzip(12)", "brotli(15)", "zstd(23)"], |
| ) |
| def test_write_parquet_with_options_wrong_compression_level(df, tmp_path, compression): |
| path = tmp_path |
| |
| with pytest.raises(Exception, match=r"valid compression range .*? exceeded."): |
| df.write_parquet_with_options( |
| str(path), ParquetWriterOptions(compression=compression) |
| ) |
| |
| |
| @pytest.mark.parametrize("compression", ["wrong", "wrong(12)"]) |
| def test_write_parquet_with_options_invalid_compression(df, tmp_path, compression): |
| path = tmp_path |
| |
| with pytest.raises(Exception, match="Unknown or unsupported parquet compression"): |
| df.write_parquet_with_options( |
| str(path), ParquetWriterOptions(compression=compression) |
| ) |
| |
| |
| @pytest.mark.parametrize( |
| ("writer_version", "format_version"), |
| [("1.0", "1.0"), ("2.0", "2.6"), (None, "1.0")], |
| ) |
| def test_write_parquet_with_options_writer_version( |
| df, tmp_path, writer_version, format_version |
| ): |
| """Test the Parquet writer version. Note that writer_version=2.0 results in |
| format_version=2.6""" |
| if writer_version is None: |
| df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) |
| else: |
| df.write_parquet_with_options( |
| tmp_path, ParquetWriterOptions(writer_version=writer_version) |
| ) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| assert metadata["format_version"] == format_version |
| |
| |
| @pytest.mark.parametrize("writer_version", ["1.2.3", "custom-version", "0"]) |
| def test_write_parquet_with_options_wrong_writer_version(df, tmp_path, writer_version): |
| """Test that invalid writer versions in Parquet throw an exception.""" |
| with pytest.raises( |
| Exception, match="Unknown or unsupported parquet writer version" |
| ): |
| df.write_parquet_with_options( |
| tmp_path, ParquetWriterOptions(writer_version=writer_version) |
| ) |
| |
| |
| @pytest.mark.parametrize("dictionary_enabled", [True, False, None]) |
| def test_write_parquet_with_options_dictionary_enabled( |
| df, tmp_path, dictionary_enabled |
| ): |
| """Test enabling/disabling the dictionaries in Parquet.""" |
| df.write_parquet_with_options( |
| tmp_path, ParquetWriterOptions(dictionary_enabled=dictionary_enabled) |
| ) |
| # by default, the dictionary is enabled, so None results in True |
| result = dictionary_enabled if dictionary_enabled is not None else True |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| assert col["has_dictionary_page"] == result |
| |
| |
| @pytest.mark.parametrize( |
| ("statistics_enabled", "has_statistics"), |
| [("page", True), ("chunk", True), ("none", False), (None, True)], |
| ) |
| def test_write_parquet_with_options_statistics_enabled( |
| df, tmp_path, statistics_enabled, has_statistics |
| ): |
| """Test configuring the statistics in Parquet. In pyarrow we can only check for |
| column-level statistics, so "page" and "chunk" are tested in the same way.""" |
| df.write_parquet_with_options( |
| tmp_path, ParquetWriterOptions(statistics_enabled=statistics_enabled) |
| ) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| if has_statistics: |
| assert col["statistics"] is not None |
| else: |
| assert col["statistics"] is None |
| |
| |
| @pytest.mark.parametrize("max_row_group_size", [1000, 5000, 10000, 100000]) |
| def test_write_parquet_with_options_max_row_group_size( |
| large_df, tmp_path, max_row_group_size |
| ): |
| """Test configuring the max number of rows per group in Parquet. These test cases |
| guarantee that the number of rows for each row group is max_row_group_size, given |
| the total number of rows is a multiple of max_row_group_size.""" |
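    # large_df has 100,000 rows, and every parametrized max_row_group_size
    # divides 100,000 evenly, so each row group comes out exactly full.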
| large_df.write_parquet_with_options( |
| tmp_path, ParquetWriterOptions(max_row_group_size=max_row_group_size) |
| ) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| for row_group in metadata["row_groups"]: |
| assert row_group["num_rows"] == max_row_group_size |
| |
| |
| @pytest.mark.parametrize("created_by", ["datafusion", "datafusion-python", "custom"]) |
| def test_write_parquet_with_options_created_by(df, tmp_path, created_by): |
| """Test configuring the created by metadata in Parquet.""" |
| df.write_parquet_with_options(tmp_path, ParquetWriterOptions(created_by=created_by)) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| assert metadata["created_by"] == created_by |
| |
| |
| @pytest.mark.parametrize("statistics_truncate_length", [5, 25, 50]) |
| def test_write_parquet_with_options_statistics_truncate_length( |
| df, tmp_path, statistics_truncate_length |
| ): |
| """Test configuring the truncate limit in Parquet's row-group-level statistics.""" |
| ctx = SessionContext() |
| data = { |
| "a": [ |
| "a_the_quick_brown_fox_jumps_over_the_lazy_dog", |
| "m_the_quick_brown_fox_jumps_over_the_lazy_dog", |
| "z_the_quick_brown_fox_jumps_over_the_lazy_dog", |
| ], |
| "b": ["a_smaller", "m_smaller", "z_smaller"], |
| } |
| df = ctx.from_arrow(pa.record_batch(data)) |
| df.write_parquet_with_options( |
| tmp_path, |
| ParquetWriterOptions(statistics_truncate_length=statistics_truncate_length), |
| ) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| statistics = col["statistics"] |
| assert len(statistics["min"]) <= statistics_truncate_length |
| assert len(statistics["max"]) <= statistics_truncate_length |
| |
| |
| def test_write_parquet_with_options_default_encoding(tmp_path): |
| """Test that, by default, Parquet files are written with dictionary encoding. |
| Note that dictionary encoding is not used for boolean values, so it is not tested |
| here.""" |
| ctx = SessionContext() |
| data = { |
| "a": [1, 2, 3], |
| "b": ["1", "2", "3"], |
| "c": [1.01, 2.02, 3.03], |
| } |
| df = ctx.from_arrow(pa.record_batch(data)) |
| df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| assert col["encodings"] == ("PLAIN", "RLE", "RLE_DICTIONARY") |
| |
| |
| @pytest.mark.parametrize( |
| ("encoding", "data_types", "result"), |
| [ |
| ("plain", ["int", "float", "str", "bool"], ("PLAIN", "RLE")), |
| ("rle", ["bool"], ("RLE",)), |
| ("delta_binary_packed", ["int"], ("RLE", "DELTA_BINARY_PACKED")), |
| ("delta_length_byte_array", ["str"], ("RLE", "DELTA_LENGTH_BYTE_ARRAY")), |
| ("delta_byte_array", ["str"], ("RLE", "DELTA_BYTE_ARRAY")), |
| ("byte_stream_split", ["int", "float"], ("RLE", "BYTE_STREAM_SPLIT")), |
| ], |
| ) |
| def test_write_parquet_with_options_encoding(tmp_path, encoding, data_types, result): |
| """Test different encodings in Parquet in their respective support column types.""" |
| ctx = SessionContext() |
| |
    sample_values = {
        "int": [1, 2, 3],
        "float": [1.01, 2.02, 3.03],
        "str": ["a", "b", "c"],
        "bool": [True, False, True],
    }
    data = {data_type: sample_values[data_type] for data_type in data_types}
| |
| df = ctx.from_arrow(pa.record_batch(data)) |
| df.write_parquet_with_options( |
| tmp_path, ParquetWriterOptions(encoding=encoding, dictionary_enabled=False) |
| ) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| assert col["encodings"] == result |
| |
| |
| @pytest.mark.parametrize("encoding", ["bit_packed"]) |
| def test_write_parquet_with_options_unsupported_encoding(df, tmp_path, encoding): |
| """Test that unsupported Parquet encodings do not work.""" |
| # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 |
| with pytest.raises(BaseException, match="Encoding .*? is not supported"): |
| df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) |
| |
| |
| @pytest.mark.parametrize("encoding", ["non_existent", "unknown", "plain123"]) |
| def test_write_parquet_with_options_invalid_encoding(df, tmp_path, encoding): |
| """Test that invalid Parquet encodings do not work.""" |
| with pytest.raises(Exception, match="Unknown or unsupported parquet encoding"): |
| df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) |
| |
| |
| @pytest.mark.parametrize("encoding", ["plain_dictionary", "rle_dictionary"]) |
| def test_write_parquet_with_options_dictionary_encoding_fallback( |
| df, tmp_path, encoding |
| ): |
| """Test that the dictionary encoding cannot be used as fallback in Parquet.""" |
| # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 |
| with pytest.raises( |
| BaseException, match="Dictionary encoding can not be used as fallback encoding" |
| ): |
| df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) |
| |
| |
| def test_write_parquet_with_options_bloom_filter(df, tmp_path): |
| """Test Parquet files with and without (default) bloom filters. Since pyarrow does |
| not expose any information about bloom filters, the easiest way to confirm that they |
| are actually written is to compare the file size.""" |
| path_no_bloom_filter = tmp_path / "1" |
| path_bloom_filter = tmp_path / "2" |
| |
| df.write_parquet_with_options(path_no_bloom_filter, ParquetWriterOptions()) |
| df.write_parquet_with_options( |
| path_bloom_filter, ParquetWriterOptions(bloom_filter_on_write=True) |
| ) |
| |
    size_no_bloom_filter = sum(
        os.path.getsize(file) for file in path_no_bloom_filter.rglob("*.parquet")
    )

    size_bloom_filter = sum(
        os.path.getsize(file) for file in path_bloom_filter.rglob("*.parquet")
    )
| |
| assert size_no_bloom_filter < size_bloom_filter |
| |
| |
| def test_write_parquet_with_options_column_options(df, tmp_path): |
| """Test writing Parquet files with different options for each column, which replace |
| the global configs (when provided).""" |
| data = { |
| "a": [1, 2, 3], |
| "b": ["a", "b", "c"], |
| "c": [False, True, False], |
| "d": [1.01, 2.02, 3.03], |
| "e": [4, 5, 6], |
| } |
| |
| column_specific_options = { |
| "a": ParquetColumnOptions(statistics_enabled="none"), |
| "b": ParquetColumnOptions(encoding="plain", dictionary_enabled=False), |
| "c": ParquetColumnOptions( |
| compression="snappy", encoding="rle", dictionary_enabled=False |
| ), |
| "d": ParquetColumnOptions( |
| compression="zstd(6)", |
| encoding="byte_stream_split", |
| dictionary_enabled=False, |
| statistics_enabled="none", |
| ), |
| # column "e" will use the global configs |
| } |
| |
| results = { |
| "a": { |
| "statistics": False, |
| "compression": "brotli", |
| "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), |
| }, |
| "b": { |
| "statistics": True, |
| "compression": "brotli", |
| "encodings": ("PLAIN", "RLE"), |
| }, |
| "c": { |
| "statistics": True, |
| "compression": "snappy", |
| "encodings": ("RLE",), |
| }, |
| "d": { |
| "statistics": False, |
| "compression": "zstd", |
| "encodings": ("RLE", "BYTE_STREAM_SPLIT"), |
| }, |
| "e": { |
| "statistics": True, |
| "compression": "brotli", |
| "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), |
| }, |
| } |
| |
| ctx = SessionContext() |
| df = ctx.from_arrow(pa.record_batch(data)) |
| df.write_parquet_with_options( |
| tmp_path, |
| ParquetWriterOptions( |
| compression="brotli(8)", column_specific_options=column_specific_options |
| ), |
| ) |
| |
| for file in tmp_path.rglob("*.parquet"): |
| parquet = pq.ParquetFile(file) |
| metadata = parquet.metadata.to_dict() |
| |
| for row_group in metadata["row_groups"]: |
| for col in row_group["columns"]: |
| column_name = col["path_in_schema"] |
| result = results[column_name] |
| assert (col["statistics"] is not None) == result["statistics"] |
| assert col["compression"].lower() == result["compression"].lower() |
| assert col["encodings"] == result["encodings"] |
| |
| |
| def test_write_parquet_options(df, tmp_path): |
| options = ParquetWriterOptions(compression="gzip", compression_level=6) |
| df.write_parquet(str(tmp_path), options) |
| |
| result = pq.read_table(str(tmp_path)).to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| def test_write_parquet_options_error(df, tmp_path): |
| options = ParquetWriterOptions(compression="gzip", compression_level=6) |
| with pytest.raises(ValueError): |
| df.write_parquet(str(tmp_path), options, compression_level=1) |
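

# Hedged sketch of the non-conflicting spelling implied by the ValueError above:
# compression settings passed directly as keyword arguments instead of through a
# ParquetWriterOptions object.
def test_write_parquet_compression_keywords(df, tmp_path):
    df.write_parquet(str(tmp_path), compression="gzip", compression_level=6)

    result = pq.read_table(str(tmp_path)).to_pydict()
    assert result == df.to_pydict()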
| |
| |
| def test_dataframe_export(df) -> None: |
    # Guarantees that the canonical implementation can read our dataframe export
| table = pa.table(df) |
| assert table.num_columns == 3 |
| assert table.num_rows == 3 |
| |
| desired_schema = pa.schema([("a", pa.int64())]) |
| |
| # Verify we can request a schema |
| table = pa.table(df, schema=desired_schema) |
| assert table.num_columns == 1 |
| assert table.num_rows == 3 |
| |
    # Expect a table of nulls if the schemas don't overlap
| desired_schema = pa.schema([("g", pa.string())]) |
| table = pa.table(df, schema=desired_schema) |
| assert table.num_columns == 1 |
| assert table.num_rows == 3 |
| for i in range(3): |
| assert table[0][i].as_py() is None |
| |
    # Expect an error when we cannot convert the schema
    desired_schema = pa.schema([("a", pa.float32())])
    with pytest.raises(Exception):
        pa.table(df, schema=desired_schema)

    # Expect an error when a non-nullable field has no values to provide
    desired_schema = pa.schema([("g", pa.string(), False)])
    with pytest.raises(Exception):
        pa.table(df, schema=desired_schema)
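

# Hedged sketch: pa.table(df) above works because DataFrame implements the Arrow
# PyCapsule stream interface, so a RecordBatchReader can consume the same export
# incrementally (assumes pyarrow >= 14 for RecordBatchReader.from_stream).
def test_dataframe_export_stream(df) -> None:
    reader = pa.RecordBatchReader.from_stream(df)
    table = reader.read_all()

    assert table.num_columns == 3
    assert table.num_rows == 3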
| |
| |
| def test_dataframe_transform(df): |
| def add_string_col(df_internal) -> DataFrame: |
| return df_internal.with_column("string_col", literal("string data")) |
| |
| def add_with_parameter(df_internal, value: Any) -> DataFrame: |
| return df_internal.with_column("new_col", literal(value)) |
| |
| df = df.transform(add_string_col).transform(add_with_parameter, 3) |
| |
| result = df.to_pydict() |
| |
| assert result["a"] == [1, 2, 3] |
| assert result["string_col"] == ["string data" for _i in range(3)] |
| assert result["new_col"] == [3 for _i in range(3)] |
| |
| |
| def test_dataframe_repr_html_structure(df, clean_formatter_state) -> None: |
| """Test that DataFrame._repr_html_ produces expected HTML output structure.""" |
| |
| output = df._repr_html_() |
| |
    # Since we've added a fair bit of processing to the html output, let's just verify
| # the values we are expecting in the table exist. Use regex and ignore everything |
| # between the <th></th> and <td></td>. We also don't want the closing > on the |
| # td and th segments because that is where the formatting data is written. |
| |
| headers = ["a", "b", "c"] |
| headers = [f"<th(.*?)>{v}</th>" for v in headers] |
| header_pattern = "(.*?)".join(headers) |
| header_matches = re.findall(header_pattern, output, re.DOTALL) |
| assert len(header_matches) == 1 |
| |
| # Update the pattern to handle values that may be wrapped in spans |
| body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]] |
| |
| body_lines = [ |
| f"<td(.*?)>(?:<span[^>]*?>)?{v}(?:</span>)?</td>" |
| for inner in body_data |
| for v in inner |
| ] |
| body_pattern = "(.*?)".join(body_lines) |
| |
| body_matches = re.findall(body_pattern, output, re.DOTALL) |
| |
| assert len(body_matches) == 1, "Expected pattern of values not found in HTML output" |
| |
| |
| def test_dataframe_repr_html_values(df, clean_formatter_state): |
| """Test that DataFrame._repr_html_ contains the expected data values.""" |
| html = df._repr_html_() |
| assert html is not None |
| |
| # Create a more flexible pattern that handles values being wrapped in spans |
| # This pattern will match the sequence of values 1,4,8,2,5,5,3,6,8 regardless |
| # of formatting |
| pattern = re.compile( |
| r"<td[^>]*?>(?:<span[^>]*?>)?1(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?4(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?8(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?2(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?5(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?5(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?3(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?6(?:</span>)?</td>.*?" |
| r"<td[^>]*?>(?:<span[^>]*?>)?8(?:</span>)?</td>", |
| re.DOTALL, |
| ) |
| |
| # Print debug info if the test fails |
| matches = re.findall(pattern, html) |
| if not matches: |
| print(f"HTML output snippet: {html[:500]}...") # noqa: T201 |
| |
| assert len(matches) > 0, "Expected pattern of values not found in HTML output" |
| |
| |
| def test_html_formatter_shared_styles(df, clean_formatter_state): |
| """Test that shared styles work correctly across multiple tables.""" |
| |
| # First, ensure we're using shared styles |
| configure_formatter(use_shared_styles=True) |
| |
| html_first = df._repr_html_() |
| html_second = df._repr_html_() |
| |
| assert "<script>" in html_first |
| assert "df-styles" in html_first |
| assert "<script>" in html_second |
| assert "df-styles" in html_second |
| assert "<style>" not in html_first |
| assert "<style>" not in html_second |
| |
| |
| def test_html_formatter_no_shared_styles(df, clean_formatter_state): |
| """Test that styles are always included when shared styles are disabled.""" |
| |
| # Configure formatter to NOT use shared styles |
| configure_formatter(use_shared_styles=False) |
| |
| html_first = df._repr_html_() |
| html_second = df._repr_html_() |
| |
| assert "<script>" in html_first |
| assert "<script>" in html_second |
| assert "df-styles" in html_first |
| assert "df-styles" in html_second |
| assert "<style>" not in html_first |
| assert "<style>" not in html_second |
| |
| |
| def test_html_formatter_manual_format_html(clean_formatter_state): |
| """Test direct usage of format_html method with shared styles.""" |
| |
| # Create sample data |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| formatter = get_formatter() |
| |
| html_first = formatter.format_html([batch], batch.schema) |
| html_second = formatter.format_html([batch], batch.schema) |
| |
| assert "<script>" in html_first |
| assert "<script>" in html_second |
| assert "df-styles" in html_first |
| assert "df-styles" in html_second |
| assert "<style>" not in html_first |
| assert "<style>" not in html_second |
| |
| # Create a new formatter with shared_styles=False |
| local_formatter = DataFrameHtmlFormatter(use_shared_styles=False) |
| |
| # Both calls should include styles |
| local_html_1 = local_formatter.format_html([batch], batch.schema) |
| local_html_2 = local_formatter.format_html([batch], batch.schema) |
| |
| assert "<script>" in local_html_1 |
| assert "<script>" in local_html_2 |
| assert "df-styles" in local_html_1 |
| assert "df-styles" in local_html_2 |
| assert "<style>" not in local_html_1 |
| assert "<style>" not in local_html_2 |
| |
| |
| def test_fill_null_basic(null_df): |
| """Test basic fill_null functionality with a single value.""" |
| # Fill all nulls with 0 |
| filled_df = null_df.fill_null(0) |
| |
| result = filled_df.collect()[0] |
| |
| # Check that nulls were filled with 0 (or equivalent) |
| assert result.column(0) == pa.array([1, 0, 3, 0]) |
| assert result.column(1) == pa.array([4.5, 6.7, 0.0, 0.0]) |
| # String column should be filled with "0" |
| assert result.column(2) == pa.array(["a", "0", "c", "0"]) |
| # Boolean column should be filled with False (0 converted to bool) |
| assert result.column(3) == pa.array([True, False, False, False]) |
| |
| |
| def test_fill_null_subset(null_df): |
| """Test filling nulls only in a subset of columns.""" |
| # Fill nulls only in numeric columns |
| filled_df = null_df.fill_null(0, subset=["int_col", "float_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # Check that nulls were filled only in specified columns |
| assert result.column(0) == pa.array([1, 0, 3, 0]) |
| assert result.column(1) == pa.array([4.5, 6.7, 0.0, 0.0]) |
| # These should still have nulls |
| assert None in result.column(2).to_pylist() |
| assert None in result.column(3).to_pylist() |
| |
| |
| def test_fill_null_str_column(null_df): |
| """Test filling nulls in string columns with different values.""" |
| # Fill string nulls with a replacement string |
| filled_df = null_df.fill_null("N/A", subset=["str_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # Check that string nulls were filled with "N/A" |
| assert result.column(2).to_pylist() == ["a", "N/A", "c", "N/A"] |
| |
| # Other columns should be unchanged |
| assert None in result.column(0).to_pylist() |
| assert None in result.column(1).to_pylist() |
| assert None in result.column(3).to_pylist() |
| |
| # Fill with an empty string |
| filled_df = null_df.fill_null("", subset=["str_col"]) |
| result = filled_df.collect()[0] |
| assert result.column(2).to_pylist() == ["a", "", "c", ""] |
| |
| |
| def test_fill_null_bool_column(null_df): |
| """Test filling nulls in boolean columns with different values.""" |
| # Fill bool nulls with True |
| filled_df = null_df.fill_null(value=True, subset=["bool_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # Check that bool nulls were filled with True |
| assert result.column(3).to_pylist() == [True, True, False, True] |
| |
| # Other columns should be unchanged |
| assert None in result.column(0).to_pylist() |
| |
| # Fill bool nulls with False |
| filled_df = null_df.fill_null(value=False, subset=["bool_col"]) |
| result = filled_df.collect()[0] |
| assert result.column(3).to_pylist() == [True, False, False, False] |
| |
| |
| def test_fill_null_date32_column(null_df): |
| """Test filling nulls in date32 columns.""" |
| |
| # Fill date32 nulls with a specific date (1970-01-01) |
| epoch_date = datetime.date(1970, 1, 1) |
| filled_df = null_df.fill_null(epoch_date, subset=["date32_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # Check that date32 nulls were filled with epoch date |
| dates = result.column(4).to_pylist() |
| assert dates[0] == datetime.date(2000, 1, 1) # Original value |
| assert dates[1] == epoch_date # Filled value |
| assert dates[2] == datetime.date(2022, 1, 1) # Original value |
| assert dates[3] == epoch_date # Filled value |
| |
| # Other date column should be unchanged |
| assert None in result.column(5).to_pylist() |
| |
| |
| def test_fill_null_date64_column(null_df): |
| """Test filling nulls in date64 columns.""" |
| |
| # Fill date64 nulls with a specific date (1970-01-01) |
| epoch_date = datetime.date(1970, 1, 1) |
| filled_df = null_df.fill_null(epoch_date, subset=["date64_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # Check that date64 nulls were filled with epoch date |
| dates = result.column(5).to_pylist() |
| assert dates[0] == datetime.date(2000, 1, 1) # Original value |
| assert dates[1] == epoch_date # Filled value |
| assert dates[2] == datetime.date(2022, 1, 1) # Original value |
| assert dates[3] == epoch_date # Filled value |
| |
| # Other date column should be unchanged |
| assert None in result.column(4).to_pylist() |
| |
| |
| def test_fill_null_type_coercion(null_df): |
| """Test type coercion when filling nulls with values of different types.""" |
| # Try to fill string nulls with a number |
| filled_df = null_df.fill_null(42, subset=["str_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # String nulls should be filled with string representation of the number |
| assert result.column(2).to_pylist() == ["a", "42", "c", "42"] |
| |
| # Try to fill bool nulls with a string that converts to True |
| filled_df = null_df.fill_null("true", subset=["bool_col"]) |
| result = filled_df.collect()[0] |
| |
    # This behavior depends on the implementation; check that it works without
    # error, but don't assert the exact conversion result
| assert None not in result.column(3).to_pylist() |
| |
| |
| def test_fill_null_multiple_date_columns(null_df): |
| """Test filling nulls in both date column types simultaneously.""" |
| |
| # Fill both date column types with the same date |
| test_date = datetime.date(2023, 12, 31) |
| filled_df = null_df.fill_null(test_date, subset=["date32_col", "date64_col"]) |
| |
| result = filled_df.collect()[0] |
| |
| # Check both date columns were filled correctly |
| date32_vals = result.column(4).to_pylist() |
| date64_vals = result.column(5).to_pylist() |
| |
| assert None not in date32_vals |
| assert None not in date64_vals |
| |
| assert date32_vals[1] == test_date |
| assert date32_vals[3] == test_date |
| assert date64_vals[1] == test_date |
| assert date64_vals[3] == test_date |
| |
| |
| def test_fill_null_specific_types(null_df): |
| """Test filling nulls with type-appropriate values.""" |
| # Fill with type-specific values |
| filled_df = null_df.fill_null("missing") |
| |
| result = filled_df.collect()[0] |
| |
    # Only the string column can accept "missing"; the other columns keep their
    # nulls because the value cannot be cast to their types

    assert result.column(0).to_pylist() == [1, None, 3, None]
    assert result.column(1).to_pylist() == [4.5, 6.7, None, None]
    assert result.column(2).to_pylist() == ["a", "missing", "c", "missing"]
    assert result.column(3).to_pylist() == [True, None, False, None]
| assert result.column(4).to_pylist() == [ |
| datetime.date(2000, 1, 1), |
| None, |
| datetime.date(2022, 1, 1), |
| None, |
| ] |
| assert result.column(5).to_pylist() == [ |
| datetime.date(2000, 1, 1), |
| None, |
| datetime.date(2022, 1, 1), |
| None, |
| ] |
| |
| |
| def test_fill_null_immutability(null_df): |
| """Test that original DataFrame is unchanged after fill_null.""" |
| # Get original values with nulls |
| original = null_df.collect()[0] |
| original_int_nulls = original.column(0).to_pylist().count(None) |
| |
| # Apply fill_null |
| _filled_df = null_df.fill_null(0) |
| |
| # Check that original is unchanged |
| new_original = null_df.collect()[0] |
| new_original_int_nulls = new_original.column(0).to_pylist().count(None) |
| |
| assert original_int_nulls == new_original_int_nulls |
| assert original_int_nulls > 0 # Ensure we actually had nulls in the first place |
| |
| |
| def test_fill_null_empty_df(ctx): |
| """Test fill_null on empty DataFrame.""" |
| # Create an empty DataFrame with schema |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([], type=pa.int64()), pa.array([], type=pa.string())], |
| names=["a", "b"], |
| ) |
| empty_df = ctx.create_dataframe([[batch]]) |
| |
| # Fill nulls (should work without errors) |
| filled_df = empty_df.fill_null(0) |
| |
| # Should still be empty but with same schema |
| result = filled_df.collect()[0] |
| assert len(result.column(0)) == 0 |
| assert len(result.column(1)) == 0 |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| |
| |
| def test_fill_null_all_null_column(ctx): |
| """Test fill_null on a column with all nulls.""" |
| # Create DataFrame with a column of all nulls |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([None, None, None], type=pa.string())], |
| names=["a", "b"], |
| ) |
| all_null_df = ctx.create_dataframe([[batch]]) |
| |
| # Fill nulls with a value |
| filled_df = all_null_df.fill_null("filled") |
| |
| # Check that all nulls were filled |
| result = filled_df.collect()[0] |
| assert result.column(1).to_pylist() == ["filled", "filled", "filled"] |
| |
| |
| def test_collect_interrupted(): |
| """Test that a long-running query can be interrupted with Ctrl-C. |
| |
| This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt |
| exception in the main thread during a long-running query execution. |
| """ |
| # Create a context and a DataFrame with a query that will run for a while |
| ctx = SessionContext() |
| |
    # Create enough data that the join query below takes a while to execute
| batches = [] |
| for i in range(10): |
| batch = pa.RecordBatch.from_arrays( |
| [ |
| pa.array(list(range(i * 1000, (i + 1) * 1000))), |
| pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), |
| ], |
| names=["a", "b"], |
| ) |
| batches.append(batch) |
| |
| # Register tables |
| ctx.register_record_batches("t1", [batches]) |
| ctx.register_record_batches("t2", [batches]) |
| |
| # Create a large join operation that will take time to process |
| df = ctx.sql(""" |
| WITH t1_expanded AS ( |
| SELECT |
| a, |
| b, |
| CAST(a AS DOUBLE) / 1.5 AS c, |
| CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d |
| FROM t1 |
| CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) |
| ), |
| t2_expanded AS ( |
| SELECT |
| a, |
| b, |
| CAST(a AS DOUBLE) * 2.5 AS e, |
| CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f |
| FROM t2 |
| CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) |
| ) |
| SELECT |
| t1.a, t1.b, t1.c, t1.d, |
| t2.a AS a2, t2.b AS b2, t2.e, t2.f |
| FROM t1_expanded t1 |
| JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 |
| WHERE t1.a > 100 AND t2.a > 100 |
| """) |
| |
| # Flag to track if the query was interrupted |
| interrupted = False |
| interrupt_error = None |
| main_thread = threading.main_thread() |
| |
| # Shared flag to indicate query execution has started |
| query_started = threading.Event() |
| max_wait_time = 5.0 # Maximum wait time in seconds |
| |
| # This function will be run in a separate thread and will raise |
| # KeyboardInterrupt in the main thread |
| def trigger_interrupt(): |
| """Poll for query start, then raise KeyboardInterrupt in the main thread""" |
| # Poll for query to start with small sleep intervals |
| start_time = time.time() |
| while not query_started.is_set(): |
| time.sleep(0.1) # Small sleep between checks |
| if time.time() - start_time > max_wait_time: |
| msg = f"Query did not start within {max_wait_time} seconds" |
| raise RuntimeError(msg) |
| |
| # Check if thread ID is available |
| thread_id = main_thread.ident |
| if thread_id is None: |
| msg = "Cannot get main thread ID" |
| raise RuntimeError(msg) |
| |
| # Use ctypes to raise exception in main thread |
| exception = ctypes.py_object(KeyboardInterrupt) |
| res = ctypes.pythonapi.PyThreadState_SetAsyncExc( |
| ctypes.c_long(thread_id), exception |
| ) |
| if res != 1: |
| # If res is 0, the thread ID was invalid |
| # If res > 1, we modified multiple threads |
| ctypes.pythonapi.PyThreadState_SetAsyncExc( |
| ctypes.c_long(thread_id), ctypes.py_object(0) |
| ) |
| msg = "Failed to raise KeyboardInterrupt in main thread" |
| raise RuntimeError(msg) |
| |
| # Start a thread to trigger the interrupt |
| interrupt_thread = threading.Thread(target=trigger_interrupt) |
    # We mark it as a daemon thread so the test process can exit even if it doesn't finish
| interrupt_thread.daemon = True |
| interrupt_thread.start() |
| |
| # Execute the query and expect it to be interrupted |
| try: |
| # Signal that we're about to start the query |
| query_started.set() |
| df.collect() |
| except KeyboardInterrupt: |
| interrupted = True |
| except Exception as e: |
| interrupt_error = e |
| |
| # Assert that the query was interrupted properly |
| if not interrupted: |
| pytest.fail(f"Query was not interrupted; got error: {interrupt_error}") |
| |
| # Make sure the interrupt thread has finished |
| interrupt_thread.join(timeout=1.0) |
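

# Hedged standalone sketch of the ctypes PyThreadState_SetAsyncExc technique
# used above, isolated from any query so the mechanism itself can be verified.
# The worker must spin in Python bytecode: the async exception is only
# delivered while the target thread is executing bytecode.
def test_async_exception_mechanism():
    started = threading.Event()
    caught = threading.Event()

    def worker():
        started.set()
        try:
            while True:  # spin so the pending exception can be delivered
                pass
        except KeyboardInterrupt:
            caught.set()

    worker_thread = threading.Thread(target=worker, daemon=True)
    worker_thread.start()
    assert started.wait(timeout=5.0)
    assert worker_thread.ident is not None

    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(worker_thread.ident), ctypes.py_object(KeyboardInterrupt)
    )
    assert res == 1  # exactly one thread state was modified

    worker_thread.join(timeout=5.0)
    assert caught.is_set()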