| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import pyarrow as pa |
| import pytest |
| from datafusion import SessionContext, col |
| from datafusion.expr import ( |
| Aggregate, |
| AggregateFunction, |
| BinaryExpr, |
| Column, |
| Filter, |
| Limit, |
| Literal, |
| Projection, |
| Sort, |
| TableScan, |
| ) |
| |
| |
@pytest.fixture
def test_ctx():
    """Return a fresh ``SessionContext`` with the sample CSV registered as ``test``."""
    csv_path = "testing/data/csv/aggregate_test_100.csv"
    session = SessionContext()
    session.register_csv("test", csv_path)
    return session
| |
| |
def test_projection(test_ctx):
    """Check the variants produced for a projection over a column, a literal
    and a binary comparison, and that the projection's input is a table scan.
    """
    query = "select c1, 123, c1 < 123 from test"
    projection = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(projection, Projection)

    projected = projection.projections()

    # First projected expression: the plain column reference.
    column_variant = projected[0].to_variant()
    assert isinstance(column_variant, Column)
    assert column_variant.name() == "c1"
    assert column_variant.qualified_name() == "test.c1"

    # Second projected expression: the integer literal.
    literal_variant = projected[1].to_variant()
    assert isinstance(literal_variant, Literal)
    assert literal_variant.data_type() == "Int64"
    assert literal_variant.value_i64() == 123

    # Third projected expression: the `<` comparison of column and literal.
    comparison = projected[2].to_variant()
    assert isinstance(comparison, BinaryExpr)
    assert isinstance(comparison.left().to_variant(), Column)
    assert comparison.op() == "<"
    assert isinstance(comparison.right().to_variant(), Literal)

    # The projection's first input is expected to be the scan of `test`.
    source = projection.input()[0].to_variant()
    assert isinstance(source, TableScan)
| |
| |
def test_filter(test_ctx):
    """Check that a ``WHERE`` query yields a projection whose input is a filter."""
    query = "select c1 from test WHERE c1 > 5"
    root = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(root, Projection)

    child = root.input()[0].to_variant()
    assert isinstance(child, Filter)
| |
| |
def test_limit(test_ctx):
    """Check that ``LIMIT`` queries, with and without ``OFFSET``, have a
    ``Limit`` node at the root of their logical plan.
    """
    # TODO: Upstream now has expressions for skip and fetch, so the skip
    # values (0 without OFFSET, 5 with OFFSET 5) are not asserted here.
    # REF: https://github.com/apache/datafusion/pull/12836
    queries = (
        "select c1 from test LIMIT 10",
        "select c1 from test LIMIT 10 OFFSET 5",
    )
    for query in queries:
        root = test_ctx.sql(query).logical_plan().to_variant()
        assert isinstance(root, Limit)
| |
| |
def test_aggregate_query(test_ctx):
    """Check the plan shape and expression variants of a ``GROUP BY`` count query."""
    query = "select c1, count(*) from test group by c1"
    root = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(root, Projection)

    agg_node = root.input()[0].to_variant()
    assert isinstance(agg_node, Aggregate)

    # The single grouping expression is the column `c1`.
    group_key = agg_node.group_by_exprs()[0].to_variant()
    assert isinstance(group_key, Column)
    assert group_key.name() == "c1"
    assert group_key.qualified_name() == "test.c1"

    # The first aggregate expression is an aggregate function call.
    agg_call = agg_node.aggregate_exprs()[0].to_variant()
    assert isinstance(agg_call, AggregateFunction)
| |
| |
def test_sort(test_ctx):
    """Check that an ``ORDER BY`` query has a ``Sort`` node at the plan root."""
    query = "select c1 from test order by c1"
    root = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(root, Sort)
| |
| |
def test_relational_expr(test_ctx):
    """Check row counts produced by relational operators on an integer column
    and a string-view column, and that a mismatched comparison raises.
    """
    # NOTE: the `test_ctx` fixture is requested but not referenced below;
    # the test builds its own context around an in-memory batch.
    session = SessionContext()

    numbers = pa.array([1, 2, 3])
    words = pa.array(["alpha", "beta", "gamma"], type=pa.string_view())
    batch = pa.RecordBatch.from_arrays([numbers, words], names=["a", "b"])
    df = session.create_dataframe([[batch]], name="batch_array")

    def matching_rows(predicate):
        """Return how many rows of ``df`` satisfy ``predicate``."""
        return df.filter(predicate).count()

    # Integer column comparisons.
    assert matching_rows(col("a") == 1) == 1
    assert matching_rows(col("a") != 1) == 2
    assert matching_rows(col("a") >= 1) == 3
    assert matching_rows(col("a") > 1) == 2
    assert matching_rows(col("a") <= 3) == 3
    assert matching_rows(col("a") < 3) == 2

    # String-view column comparisons.
    assert matching_rows(col("b") == "beta") == 1
    assert matching_rows(col("b") != "beta") == 2

    # Comparing the integer column with a string is expected to fail.
    with pytest.raises(Exception):
        matching_rows(col("a") == "beta")
| |
| |
def test_expr_to_variant():
    """Locate the filter predicate of an ``IN`` query in the optimized plan
    and check the expression, value list and negation flag of its variant.
    """
    # Taken from https://github.com/apache/datafusion-python/issues/781
    from datafusion import SessionContext
    from datafusion.expr import Filter

    def find_filter_predicate(node):
        """Depth-first search for the first ``Filter``; return its predicate
        variant, or ``None`` when no filter is found under ``node``.
        """
        variant = node.to_variant()
        if isinstance(variant, Filter):
            return variant.predicate().to_variant()
        if not hasattr(node, "inputs"):
            return None
        for child in node.inputs():
            found = find_filter_predicate(child)
            if found is not None:
                return found
        return None

    session = SessionContext()
    rows = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]}
    session.from_pydict(rows, name="table1")

    sql = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
    optimized = session.sql(sql).optimized_logical_plan()
    in_list = find_filter_predicate(optimized)

    assert in_list is not None
    assert in_list.expr().to_variant().qualified_name() == "table1.name"
    expected_values = (
        '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
    )
    assert str(in_list.list()) == expected_values
    assert not in_list.negated()
| |
| |
def test_expr_getitem() -> None:
    """Check ``Expr`` indexing: field access on a struct column by name and
    element access on an array column by position.
    """
    session = SessionContext()
    people = [
        {"name": "Alice", "age": 15},
        {"name": "Bob", "age": 14},
        {"name": "Charlie", "age": 13},
        {"name": None, "age": 12},
    ]
    table = {
        "array_values": [[1, 2, 3], [4, 5], [6], []],
        "struct_values": people,
    }
    df = session.from_pydict(table, name="table1")

    def column_values(batches, column):
        """Flatten ``column`` across all ``batches`` into Python objects."""
        values = []
        for batch in batches:
            values.extend(item.as_py() for item in batch[column])
        return values

    name_expr = col("struct_values")["name"].alias("name")
    names = column_values(df.select(name_expr).collect(), "name")

    element_expr = col("array_values")[1].alias("value")
    array_values = column_values(df.select(element_expr).collect(), "value")

    assert names == ["Alice", "Bob", "Charlie", None]
    assert array_values == [2, 5, None, None]
| |
| |
def test_display_name_deprecation():
    """Check that ``Expr.display_name`` emits exactly one deprecation warning
    and still returns the same value as ``Expr.schema_name``.
    """
    import warnings

    expr = col("foo")
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, regardless of any previously installed filter.
        warnings.simplefilter("always")

        # The deprecated accessor should warn when called.
        name = expr.display_name()

        # Exactly one warning, of the right category, mentioning deprecation.
        assert len(caught) == 1
        latest = caught[-1]
        assert issubclass(latest.category, DeprecationWarning)
        assert "deprecated" in str(latest.message)

    # The deprecated accessor still returns the expected name.
    assert name == expr.schema_name()
    assert name == "foo"
| |
| |
@pytest.fixture
def df():
    """Return a DataFrame with three nullable columns ``a``, ``b`` and ``c``."""
    session = SessionContext()

    # Build a single RecordBatch whose columns each contain at least one null.
    columns = [
        pa.array([1, 2, None]),
        pa.array([4, None, 6]),
        pa.array([None, None, 8]),
    ]
    batch = pa.RecordBatch.from_arrays(columns, names=["a", "b", "c"])

    return session.from_arrow(batch)
| |
| |
def test_fill_null(df):
    """Check that ``fill_null`` replaces nulls in each column with the given
    value and leaves the non-null entries unchanged.
    """
    replacements = {"a": 100, "b": 25, "c": 1234}
    filled_exprs = [
        col(column).fill_null(value).alias(column)
        for column, value in replacements.items()
    ]
    filled = df.select(*filled_exprs)
    filled.show()
    first_batch = filled.collect()[0]

    assert first_batch.column(0) == pa.array([1, 2, 100])
    assert first_batch.column(1) == pa.array([4, 25, 6])
    assert first_batch.column(2) == pa.array([1234, 1234, 8])