| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import os |
| import re |
| from typing import Any |
| |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| import pytest |
| from datafusion import ( |
| DataFrame, |
| SessionContext, |
| WindowFrame, |
| column, |
| literal, |
| ) |
| from datafusion import functions as f |
| from datafusion.expr import Window |
| from pyarrow.csv import write_csv |
| |
| |
| @pytest.fixture |
| def ctx(): |
| return SessionContext() |
| |
| |
| @pytest.fixture |
| def df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])], |
| names=["a", "b", "c"], |
| ) |
| |
| return ctx.from_arrow(batch) |
| |
| |
| @pytest.fixture |
| def struct_df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([{"c": 1}, {"c": 2}, {"c": 3}]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
| @pytest.fixture |
| def nested_df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| # Intentionally make each array of different length |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([[1], [2, 3], [4, 5, 6], None]), pa.array([7, 8, 9, 10])], |
| names=["a", "b"], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
| @pytest.fixture |
| def aggregate_df(): |
| ctx = SessionContext() |
| ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv") |
| return ctx.sql("select c1, sum(c2) from test group by c1") |
| |
| |
| @pytest.fixture |
| def partitioned_df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [ |
| pa.array([0, 1, 2, 3, 4, 5, 6]), |
| pa.array([7, None, 7, 8, 9, None, 9]), |
| pa.array(["A", "A", "A", "A", "B", "B", "B"]), |
| ], |
| names=["a", "b", "c"], |
| ) |
| |
| return ctx.create_dataframe([[batch]]) |
| |
| |
| def test_select(df): |
| df_1 = df.select( |
| column("a") + column("b"), |
| column("a") - column("b"), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df_1.collect()[0] |
| |
| assert result.column(0) == pa.array([5, 7, 9]) |
| assert result.column(1) == pa.array([-3, -3, -3]) |
| |
| df_2 = df.select("b", "a") |
| |
| # execute and collect the first (and only) batch |
| result = df_2.collect()[0] |
| |
| assert result.column(0) == pa.array([4, 5, 6]) |
| assert result.column(1) == pa.array([1, 2, 3]) |
| |
| |
| def test_select_mixed_expr_string(df): |
| df = df.select(column("b"), "a") |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([4, 5, 6]) |
| assert result.column(1) == pa.array([1, 2, 3]) |
| |
| |
| def test_filter(df): |
| df1 = df.filter(column("a") > literal(2)).select( |
| column("a") + column("b"), |
| column("a") - column("b"), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df1.collect()[0] |
| |
| assert result.column(0) == pa.array([9]) |
| assert result.column(1) == pa.array([-3]) |
| |
| df.show() |
| # verify that if there is no filter applied, internal dataframe is unchanged |
| df2 = df.filter() |
| assert df.df == df2.df |
| |
| df3 = df.filter(column("a") > literal(1), column("b") != literal(6)) |
| result = df3.collect()[0] |
| |
| assert result.column(0) == pa.array([2]) |
| assert result.column(1) == pa.array([5]) |
| assert result.column(2) == pa.array([5]) |
| |
| |
| def test_sort(df): |
| df = df.sort(column("b").sort(ascending=False)) |
| |
| table = pa.Table.from_batches(df.collect()) |
| expected = {"a": [3, 2, 1], "b": [6, 5, 4], "c": [8, 5, 8]} |
| |
| assert table.to_pydict() == expected |
| |
| |
| def test_drop(df): |
| df = df.drop("c") |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert df.schema().names == ["a", "b"] |
| assert result.column(0) == pa.array([1, 2, 3]) |
| assert result.column(1) == pa.array([4, 5, 6]) |
| |
| |
| def test_limit(df): |
| df = df.limit(1) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert len(result.column(0)) == 1 |
| assert len(result.column(1)) == 1 |
| |
| |
| def test_limit_with_offset(df): |
| # only 3 rows, but limit past the end to ensure that offset is working |
| df = df.limit(5, offset=2) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert len(result.column(0)) == 1 |
| assert len(result.column(1)) == 1 |
| |
| |
| def test_head(df): |
| df = df.head(1) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([1]) |
| assert result.column(1) == pa.array([4]) |
| assert result.column(2) == pa.array([8]) |
| |
| |
| def test_tail(df): |
| df = df.tail(1) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([3]) |
| assert result.column(1) == pa.array([6]) |
| assert result.column(2) == pa.array([8]) |
| |
| |
| def test_with_column(df): |
| df = df.with_column("c", column("a") + column("b")) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| assert result.schema.field(2).name == "c" |
| |
| assert result.column(0) == pa.array([1, 2, 3]) |
| assert result.column(1) == pa.array([4, 5, 6]) |
| assert result.column(2) == pa.array([5, 7, 9]) |
| |
| |
| def test_with_columns(df): |
| df = df.with_columns( |
| (column("a") + column("b")).alias("c"), |
| (column("a") + column("b")).alias("d"), |
| [ |
| (column("a") + column("b")).alias("e"), |
| (column("a") + column("b")).alias("f"), |
| ], |
| g=(column("a") + column("b")), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| assert result.schema.field(2).name == "c" |
| assert result.schema.field(3).name == "d" |
| assert result.schema.field(4).name == "e" |
| assert result.schema.field(5).name == "f" |
| assert result.schema.field(6).name == "g" |
| |
| assert result.column(0) == pa.array([1, 2, 3]) |
| assert result.column(1) == pa.array([4, 5, 6]) |
| assert result.column(2) == pa.array([5, 7, 9]) |
| assert result.column(3) == pa.array([5, 7, 9]) |
| assert result.column(4) == pa.array([5, 7, 9]) |
| assert result.column(5) == pa.array([5, 7, 9]) |
| assert result.column(6) == pa.array([5, 7, 9]) |
| |
| |
| def test_cast(df): |
| df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())}) |
| expected = pa.schema( |
| [("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())] |
| ) |
| |
| assert df.schema() == expected |
| |
| |
| def test_with_column_renamed(df): |
| df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum") |
| |
| result = df.collect()[0] |
| |
| assert result.schema.field(0).name == "a" |
| assert result.schema.field(1).name == "b" |
| assert result.schema.field(2).name == "sum" |
| |
| |
| def test_unnest(nested_df): |
| nested_df = nested_df.unnest_columns("a") |
| |
| # execute and collect the first (and only) batch |
| result = nested_df.collect()[0] |
| |
| assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6, None]) |
| assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9, 10]) |
| |
| |
| def test_unnest_without_nulls(nested_df): |
| nested_df = nested_df.unnest_columns("a", preserve_nulls=False) |
| |
| # execute and collect the first (and only) batch |
| result = nested_df.collect()[0] |
| |
| assert result.column(0) == pa.array([1, 2, 3, 4, 5, 6]) |
| assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9]) |
| |
| |
| @pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning") |
| def test_join(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], "l") |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([8, 10])], |
| names=["a", "c"], |
| ) |
| df1 = ctx.create_dataframe([[batch]], "r") |
| |
| df2 = df.join(df1, on="a", how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| df2 = df.join(df1, left_on="a", right_on="a", how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| # Verify we don't make a breaking change to pre-43.0.0 |
| # where users would pass join_keys as a positional argument |
| df2 = df.join(df1, (["a"], ["a"]), how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| |
| def test_join_invalid_params(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], "l") |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([8, 10])], |
| names=["a", "c"], |
| ) |
| df1 = ctx.create_dataframe([[batch]], "r") |
| |
| with pytest.deprecated_call(): |
| df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| with pytest.raises( |
| ValueError, match=r"`left_on` or `right_on` should not provided with `on`" |
| ): |
| df2 = df.join(df1, on="a", how="inner", right_on="test") |
| |
| with pytest.raises( |
| ValueError, match=r"`left_on` and `right_on` should both be provided." |
| ): |
| df2 = df.join(df1, left_on="a", how="inner") |
| |
| with pytest.raises( |
| ValueError, match=r"either `on` or `left_on` and `right_on` should be provided." |
| ): |
| df2 = df.join(df1, how="inner") |
| |
| |
| def test_join_on(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], "l") |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([-8, 10])], |
| names=["a", "c"], |
| ) |
| df1 = ctx.create_dataframe([[batch]], "r") |
| |
| df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner") |
| df2.show() |
| df2 = df2.sort(column("l.a")) |
| table = pa.Table.from_batches(df2.collect()) |
| |
| expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]} |
| assert table.to_pydict() == expected |
| |
| df3 = df.join_on( |
| df1, |
| column("l.a").__eq__(column("r.a")), |
| column("l.a").__lt__(column("r.c")), |
| how="inner", |
| ) |
| df3.show() |
| df3 = df3.sort(column("l.a")) |
| table = pa.Table.from_batches(df3.collect()) |
| expected = {"a": [2], "c": [10], "b": [5]} |
| assert table.to_pydict() == expected |
| |
| |
| def test_distinct(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3, 1, 2, 3]), pa.array([4, 5, 6, 4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]).distinct().sort(column("a")) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| assert df_a.collect() == df_b.collect() |
| |
| |
| data_test_window_functions = [ |
| ( |
| "row", |
| f.row_number(order_by=[column("b"), column("a").sort(ascending=False)]), |
| [4, 2, 3, 5, 7, 1, 6], |
| ), |
| ( |
| "row_w_params", |
| f.row_number( |
| order_by=[column("b"), column("a")], |
| partition_by=[column("c")], |
| ), |
| [2, 1, 3, 4, 2, 1, 3], |
| ), |
| ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]), |
| ( |
| "rank_w_params", |
| f.rank(order_by=[column("b"), column("a")], partition_by=[column("c")]), |
| [2, 1, 3, 4, 2, 1, 3], |
| ), |
| ( |
| "dense_rank", |
| f.dense_rank(order_by=[column("b")]), |
| [2, 1, 2, 3, 4, 1, 4], |
| ), |
| ( |
| "dense_rank_w_params", |
| f.dense_rank(order_by=[column("b"), column("a")], partition_by=[column("c")]), |
| [2, 1, 3, 4, 2, 1, 3], |
| ), |
| ( |
| "percent_rank", |
| f.round(f.percent_rank(order_by=[column("b")]), literal(3)), |
| [0.333, 0.0, 0.333, 0.667, 0.833, 0.0, 0.833], |
| ), |
| ( |
| "percent_rank_w_params", |
| f.round( |
| f.percent_rank( |
| order_by=[column("b"), column("a")], partition_by=[column("c")] |
| ), |
| literal(3), |
| ), |
| [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0], |
| ), |
| ( |
| "cume_dist", |
| f.round(f.cume_dist(order_by=[column("b")]), literal(3)), |
| [0.571, 0.286, 0.571, 0.714, 1.0, 0.286, 1.0], |
| ), |
| ( |
| "cume_dist_w_params", |
| f.round( |
| f.cume_dist( |
| order_by=[column("b"), column("a")], partition_by=[column("c")] |
| ), |
| literal(3), |
| ), |
| [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0], |
| ), |
| ( |
| "ntile", |
| f.ntile(2, order_by=[column("b")]), |
| [1, 1, 1, 2, 2, 1, 2], |
| ), |
| ( |
| "ntile_w_params", |
| f.ntile(2, order_by=[column("b"), column("a")], partition_by=[column("c")]), |
| [1, 1, 2, 2, 1, 1, 2], |
| ), |
| ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9, 7, None]), |
| ( |
| "lead_w_params", |
| f.lead( |
| column("b"), |
| shift_offset=2, |
| default_value=-1, |
| order_by=[column("b"), column("a")], |
| partition_by=[column("c")], |
| ), |
| [8, 7, -1, -1, -1, 9, -1], |
| ), |
| ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8, None, 9]), |
| ( |
| "lag_w_params", |
| f.lag( |
| column("b"), |
| shift_offset=2, |
| default_value=-1, |
| order_by=[column("b"), column("a")], |
| partition_by=[column("c")], |
| ), |
| [-1, -1, None, 7, -1, -1, None], |
| ), |
| ( |
| "first_value", |
| f.first_value(column("a")).over( |
| Window(partition_by=[column("c")], order_by=[column("b")]) |
| ), |
| [1, 1, 1, 1, 5, 5, 5], |
| ), |
| ( |
| "last_value", |
| f.last_value(column("a")).over( |
| Window( |
| partition_by=[column("c")], |
| order_by=[column("b")], |
| window_frame=WindowFrame("rows", None, None), |
| ) |
| ), |
| [3, 3, 3, 3, 6, 6, 6], |
| ), |
| ( |
| "3rd_value", |
| f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])), |
| [None, None, 7, 7, 7, 7, 7], |
| ), |
| ( |
| "avg", |
| f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)), |
| [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0], |
| ), |
| ] |
| |
| |
| @pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions) |
| def test_window_functions(partitioned_df, name, expr, result): |
| df = partitioned_df.select( |
| column("a"), column("b"), column("c"), f.alias(expr, name) |
| ) |
| df.sort(column("a")).show() |
| table = pa.Table.from_batches(df.collect()) |
| |
| expected = { |
| "a": [0, 1, 2, 3, 4, 5, 6], |
| "b": [7, None, 7, 8, 9, None, 9], |
| "c": ["A", "A", "A", "A", "B", "B", "B"], |
| name: result, |
| } |
| |
| assert table.sort_by("a").to_pydict() == expected |
| |
| |
| @pytest.mark.parametrize( |
| ("units", "start_bound", "end_bound"), |
| [ |
| (units, start_bound, end_bound) |
| for units in ("rows", "range") |
| for start_bound in (None, 0, 1) |
| for end_bound in (None, 0, 1) |
| ] |
| + [ |
| ("groups", 0, 0), |
| ], |
| ) |
| def test_valid_window_frame(units, start_bound, end_bound): |
| WindowFrame(units, start_bound, end_bound) |
| |
| |
| @pytest.mark.parametrize( |
| ("units", "start_bound", "end_bound"), |
| [ |
| ("invalid-units", 0, None), |
| ("invalid-units", None, 0), |
| ("invalid-units", None, None), |
| ("groups", None, 0), |
| ("groups", 0, None), |
| ("groups", None, None), |
| ], |
| ) |
| def test_invalid_window_frame(units, start_bound, end_bound): |
| with pytest.raises(RuntimeError): |
| WindowFrame(units, start_bound, end_bound) |
| |
| |
| def test_window_frame_defaults_match_postgres(partitioned_df): |
| # ref: https://github.com/apache/datafusion-python/issues/688 |
| |
| window_frame = WindowFrame("rows", None, None) |
| |
| col_a = column("a") |
| |
| # Using `f.window` with or without an unbounded window_frame produces the same |
| # results. These tests are included as a regression check but can be removed when |
| # f.window() is deprecated in favor of using the .over() approach. |
| no_frame = f.window("avg", [col_a]).alias("no_frame") |
| with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame") |
| df_1 = partitioned_df.select(col_a, no_frame, with_frame) |
| |
| expected = { |
| "a": [0, 1, 2, 3, 4, 5, 6], |
| "no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], |
| "with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], |
| } |
| |
| assert df_1.sort(col_a).to_pydict() == expected |
| |
| # When order is not set, the default frame should be unounded preceeding to |
| # unbounded following. When order is set, the default frame is unbounded preceeding |
| # to current row. |
| no_order = f.avg(col_a).over(Window()).alias("over_no_order") |
| with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order") |
| df_2 = partitioned_df.select(col_a, no_order, with_order) |
| |
| expected = { |
| "a": [0, 1, 2, 3, 4, 5, 6], |
| "over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], |
| "over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0], |
| } |
| |
| assert df_2.sort(col_a).to_pydict() == expected |
| |
| |
| def test_get_dataframe(tmp_path): |
| ctx = SessionContext() |
| |
| path = tmp_path / "test.csv" |
| table = pa.Table.from_arrays( |
| [ |
| [1, 2, 3, 4], |
| ["a", "b", "c", "d"], |
| [1.1, 2.2, 3.3, 4.4], |
| ], |
| names=["int", "str", "float"], |
| ) |
| write_csv(table, path) |
| |
| ctx.register_csv("csv", path) |
| |
| df = ctx.table("csv") |
| assert isinstance(df, DataFrame) |
| |
| |
| def test_struct_select(struct_df): |
| df = struct_df.select( |
| column("a")["c"] + column("b"), |
| column("a")["c"] - column("b"), |
| ) |
| |
| # execute and collect the first (and only) batch |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([5, 7, 9]) |
| assert result.column(1) == pa.array([-3, -3, -3]) |
| |
| |
| def test_explain(df): |
| df = df.select( |
| column("a") + column("b"), |
| column("a") - column("b"), |
| ) |
| df.explain() |
| |
| |
| def test_logical_plan(aggregate_df): |
| plan = aggregate_df.logical_plan() |
| |
| expected = "Projection: test.c1, sum(test.c2)" |
| |
| assert expected == plan.display() |
| |
| expected = ( |
| "Projection: test.c1, sum(test.c2)\n" |
| " Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" |
| " TableScan: test" |
| ) |
| |
| assert expected == plan.display_indent() |
| |
| |
| def test_optimized_logical_plan(aggregate_df): |
| plan = aggregate_df.optimized_logical_plan() |
| |
| expected = "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]" |
| |
| assert expected == plan.display() |
| |
| expected = ( |
| "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" |
| " TableScan: test projection=[c1, c2]" |
| ) |
| |
| assert expected == plan.display_indent() |
| |
| |
| def test_execution_plan(aggregate_df): |
| plan = aggregate_df.execution_plan() |
| |
| expected = ( |
| "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" |
| ) |
| |
| assert expected == plan.display() |
| |
| # Check the number of partitions is as expected. |
| assert isinstance(plan.partition_count, int) |
| |
| expected = ( |
| "ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n" |
| " Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n" |
| " TableScan: test projection=[c1, c2]" |
| ) |
| |
| indent = plan.display_indent() |
| |
| # indent plan will be different for everyone due to absolute path |
| # to filename, so we just check for some expected content |
| assert "AggregateExec:" in indent |
| assert "CoalesceBatchesExec:" in indent |
| assert "RepartitionExec:" in indent |
| assert "DataSourceExec:" in indent |
| assert "file_type=csv" in indent |
| |
| ctx = SessionContext() |
| rows_returned = 0 |
| for idx in range(plan.partition_count): |
| stream = ctx.execute(plan, idx) |
| try: |
| batch = stream.next() |
| assert batch is not None |
| rows_returned += len(batch.to_pyarrow()[0]) |
| except StopIteration: |
| # This is one of the partitions with no values |
| pass |
| with pytest.raises(StopIteration): |
| stream.next() |
| |
| assert rows_returned == 5 |
| |
| |
| @pytest.mark.asyncio |
| async def test_async_iteration_of_df(aggregate_df): |
| rows_returned = 0 |
| async for batch in aggregate_df.execute_stream(): |
| assert batch is not None |
| rows_returned += len(batch.to_pyarrow()[0]) |
| |
| assert rows_returned == 5 |
| |
| |
| def test_repartition(df): |
| df.repartition(2) |
| |
| |
| def test_repartition_by_hash(df): |
| df.repartition_by_hash(column("a"), num=2) |
| |
| |
| def test_intersect(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3]), pa.array([6])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_i_b = df_a.intersect(df_b).sort(column("a")) |
| |
| assert df_c.collect() == df_a_i_b.collect() |
| |
| |
| def test_except_all(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2]), pa.array([4, 5])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_e_b = df_a.except_all(df_b).sort(column("a")) |
| |
| assert df_c.collect() == df_a_e_b.collect() |
| |
| |
| def test_collect_partitioned(): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() |
| |
| |
| def test_union(ctx): |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_u_b = df_a.union(df_b).sort(column("a")) |
| |
| assert df_c.collect() == df_a_u_b.collect() |
| |
| |
| def test_union_distinct(ctx): |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| df_a = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([3, 4, 5]), pa.array([6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_b = ctx.create_dataframe([[batch]]) |
| |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])], |
| names=["a", "b"], |
| ) |
| df_c = ctx.create_dataframe([[batch]]).sort(column("a")) |
| |
| df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a")) |
| |
| assert df_c.collect() == df_a_u_b.collect() |
| assert df_c.collect() == df_a_u_b.collect() |
| |
| |
| def test_cache(df): |
| assert df.cache().collect() == df.collect() |
| |
| |
| def test_count(df): |
| # Get number of rows |
| assert df.count() == 3 |
| |
| |
| def test_to_pandas(df): |
| # Skip test if pandas is not installed |
| pd = pytest.importorskip("pandas") |
| |
| # Convert datafusion dataframe to pandas dataframe |
| pandas_df = df.to_pandas() |
| assert isinstance(pandas_df, pd.DataFrame) |
| assert pandas_df.shape == (3, 3) |
| assert set(pandas_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_empty_to_pandas(df): |
| # Skip test if pandas is not installed |
| pd = pytest.importorskip("pandas") |
| |
| # Convert empty datafusion dataframe to pandas dataframe |
| pandas_df = df.limit(0).to_pandas() |
| assert isinstance(pandas_df, pd.DataFrame) |
| assert pandas_df.shape == (0, 3) |
| assert set(pandas_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_to_polars(df): |
| # Skip test if polars is not installed |
| pl = pytest.importorskip("polars") |
| |
| # Convert datafusion dataframe to polars dataframe |
| polars_df = df.to_polars() |
| assert isinstance(polars_df, pl.DataFrame) |
| assert polars_df.shape == (3, 3) |
| assert set(polars_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_empty_to_polars(df): |
| # Skip test if polars is not installed |
| pl = pytest.importorskip("polars") |
| |
| # Convert empty datafusion dataframe to polars dataframe |
| polars_df = df.limit(0).to_polars() |
| assert isinstance(polars_df, pl.DataFrame) |
| assert polars_df.shape == (0, 3) |
| assert set(polars_df.columns) == {"a", "b", "c"} |
| |
| |
| def test_to_arrow_table(df): |
| # Convert datafusion dataframe to pyarrow Table |
| pyarrow_table = df.to_arrow_table() |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (3, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| def test_execute_stream(df): |
| stream = df.execute_stream() |
| assert all(batch is not None for batch in stream) |
| assert not list(stream) # after one iteration the generator must be exhausted |
| |
| |
| @pytest.mark.asyncio |
| async def test_execute_stream_async(df): |
| stream = df.execute_stream() |
| batches = [batch async for batch in stream] |
| |
| assert all(batch is not None for batch in batches) |
| |
| # After consuming all batches, the stream should be exhausted |
| remaining_batches = [batch async for batch in stream] |
| assert not remaining_batches |
| |
| |
| @pytest.mark.parametrize("schema", [True, False]) |
| def test_execute_stream_to_arrow_table(df, schema): |
| stream = df.execute_stream() |
| |
| if schema: |
| pyarrow_table = pa.Table.from_batches( |
| (batch.to_pyarrow() for batch in stream), schema=df.schema() |
| ) |
| else: |
| pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream) |
| |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (3, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| @pytest.mark.asyncio |
| @pytest.mark.parametrize("schema", [True, False]) |
| async def test_execute_stream_to_arrow_table_async(df, schema): |
| stream = df.execute_stream() |
| |
| if schema: |
| pyarrow_table = pa.Table.from_batches( |
| [batch.to_pyarrow() async for batch in stream], schema=df.schema() |
| ) |
| else: |
| pyarrow_table = pa.Table.from_batches( |
| [batch.to_pyarrow() async for batch in stream] |
| ) |
| |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (3, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| def test_execute_stream_partitioned(df): |
| streams = df.execute_stream_partitioned() |
| assert all(batch is not None for stream in streams for batch in stream) |
| assert all( |
| not list(stream) for stream in streams |
| ) # after one iteration all generators must be exhausted |
| |
| |
| @pytest.mark.asyncio |
| async def test_execute_stream_partitioned_async(df): |
| streams = df.execute_stream_partitioned() |
| |
| for stream in streams: |
| batches = [batch async for batch in stream] |
| assert all(batch is not None for batch in batches) |
| |
| # Ensure the stream is exhausted after iteration |
| remaining_batches = [batch async for batch in stream] |
| assert not remaining_batches |
| |
| |
| def test_empty_to_arrow_table(df): |
| # Convert empty datafusion dataframe to pyarrow Table |
| pyarrow_table = df.limit(0).to_arrow_table() |
| assert isinstance(pyarrow_table, pa.Table) |
| assert pyarrow_table.shape == (0, 3) |
| assert set(pyarrow_table.column_names) == {"a", "b", "c"} |
| |
| |
| def test_to_pylist(df): |
| # Convert datafusion dataframe to Python list |
| pylist = df.to_pylist() |
| assert isinstance(pylist, list) |
| assert pylist == [ |
| {"a": 1, "b": 4, "c": 8}, |
| {"a": 2, "b": 5, "c": 5}, |
| {"a": 3, "b": 6, "c": 8}, |
| ] |
| |
| |
| def test_to_pydict(df): |
| # Convert datafusion dataframe to Python dictionary |
| pydict = df.to_pydict() |
| assert isinstance(pydict, dict) |
| assert pydict == {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]} |
| |
| |
| def test_describe(df): |
| # Calculate statistics |
| df = df.describe() |
| |
| # Collect the result |
| result = df.to_pydict() |
| |
| assert result == { |
| "describe": [ |
| "count", |
| "null_count", |
| "mean", |
| "std", |
| "min", |
| "max", |
| "median", |
| ], |
| "a": [3.0, 0.0, 2.0, 1.0, 1.0, 3.0, 2.0], |
| "b": [3.0, 0.0, 5.0, 1.0, 4.0, 6.0, 5.0], |
| "c": [3.0, 0.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0], |
| } |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_write_csv(ctx, df, tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| df.write_csv(path, with_header=True) |
| |
| ctx.register_csv("csv", path) |
| result = ctx.table("csv").to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_write_json(ctx, df, tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| df.write_json(path) |
| |
| ctx.register_json("json", path) |
| result = ctx.table("json").to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_write_parquet(df, tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| df.write_parquet(str(path)) |
| result = pq.read_table(str(path)).to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize( |
| ("compression", "compression_level"), |
| [("gzip", 6), ("brotli", 7), ("zstd", 15)], |
| ) |
| def test_write_compressed_parquet(df, tmp_path, compression, compression_level): |
| path = tmp_path |
| |
| df.write_parquet( |
| str(path), compression=compression, compression_level=compression_level |
| ) |
| |
| # test that the actual compression scheme is the one written |
| for _root, _dirs, files in os.walk(path): |
| for file in files: |
| if file.endswith(".parquet"): |
| metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() |
| for row_group in metadata["row_groups"]: |
| for columns in row_group["columns"]: |
| assert columns["compression"].lower() == compression |
| |
| result = pq.read_table(str(path)).to_pydict() |
| expected = df.to_pydict() |
| |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize( |
| ("compression", "compression_level"), |
| [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], |
| ) |
| def test_write_compressed_parquet_wrong_compression_level( |
| df, tmp_path, compression, compression_level |
| ): |
| path = tmp_path |
| |
| with pytest.raises(ValueError): |
| df.write_parquet( |
| str(path), |
| compression=compression, |
| compression_level=compression_level, |
| ) |
| |
| |
| @pytest.mark.parametrize("compression", ["wrong"]) |
| def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression): |
| path = tmp_path |
| |
| with pytest.raises(ValueError): |
| df.write_parquet(str(path), compression=compression) |
| |
| |
| # not testing lzo because it it not implemented yet |
| # https://github.com/apache/arrow-rs/issues/6970 |
| @pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"]) |
| def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression): |
| # Test write_parquet with zstd, brotli, gzip default compression level, |
| # ie don't specify compression level |
| # should complete without error |
| path = tmp_path |
| |
| df.write_parquet(str(path), compression=compression) |
| |
| |
| def test_dataframe_export(df) -> None: |
| # Guarantees that we have the canonical implementation |
| # reading our dataframe export |
| table = pa.table(df) |
| assert table.num_columns == 3 |
| assert table.num_rows == 3 |
| |
| desired_schema = pa.schema([("a", pa.int64())]) |
| |
| # Verify we can request a schema |
| table = pa.table(df, schema=desired_schema) |
| assert table.num_columns == 1 |
| assert table.num_rows == 3 |
| |
| # Expect a table of nulls if the schema don't overlap |
| desired_schema = pa.schema([("g", pa.string())]) |
| table = pa.table(df, schema=desired_schema) |
| assert table.num_columns == 1 |
| assert table.num_rows == 3 |
| for i in range(3): |
| assert table[0][i].as_py() is None |
| |
| # Expect an error when we cannot convert schema |
| desired_schema = pa.schema([("a", pa.float32())]) |
| failed_convert = False |
| try: |
| table = pa.table(df, schema=desired_schema) |
| except Exception: |
| failed_convert = True |
| assert failed_convert |
| |
| # Expect an error when we have a not set non-nullable |
| desired_schema = pa.schema([("g", pa.string(), False)]) |
| failed_convert = False |
| try: |
| table = pa.table(df, schema=desired_schema) |
| except Exception: |
| failed_convert = True |
| assert failed_convert |
| |
| |
| def test_dataframe_transform(df): |
| def add_string_col(df_internal) -> DataFrame: |
| return df_internal.with_column("string_col", literal("string data")) |
| |
| def add_with_parameter(df_internal, value: Any) -> DataFrame: |
| return df_internal.with_column("new_col", literal(value)) |
| |
| df = df.transform(add_string_col).transform(add_with_parameter, 3) |
| |
| result = df.to_pydict() |
| |
| assert result["a"] == [1, 2, 3] |
| assert result["string_col"] == ["string data" for _i in range(3)] |
| assert result["new_col"] == [3 for _i in range(3)] |
| |
| |
| def test_dataframe_repr_html(df) -> None: |
| output = df._repr_html_() |
| |
| # Since we've added a fair bit of processing to the html output, lets just verify |
| # the values we are expecting in the table exist. Use regex and ignore everything |
| # between the <th></th> and <td></td>. We also don't want the closing > on the |
| # td and th segments because that is where the formatting data is written. |
| |
| headers = ["a", "b", "c"] |
| headers = [f"<th(.*?)>{v}</th>" for v in headers] |
| header_pattern = "(.*?)".join(headers) |
| assert len(re.findall(header_pattern, output, re.DOTALL)) == 1 |
| |
| body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]] |
| body_lines = [f"<td(.*?)>{v}</td>" for inner in body_data for v in inner] |
| body_pattern = "(.*?)".join(body_lines) |
| assert len(re.findall(body_pattern, output, re.DOTALL)) == 1 |