# blob: 1a41120a554daa1a14226cf3680a959660616a3c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datafusion import SessionContext, col
from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction
from datafusion.expr import (
Projection,
Filter,
Aggregate,
Limit,
Sort,
TableScan,
)
import pyarrow
import pytest
@pytest.fixture
def test_ctx():
    """Return a new ``SessionContext`` with the CSV file registered as table ``test``."""
    session = SessionContext()
    session.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
    return session
def test_projection(test_ctx):
    """Inspect the variants of a projection of a column, a literal and a comparison.

    The query selects ``c1``, the literal ``123`` and ``c1 < 123``; each
    projected expression is converted with ``to_variant()`` and checked in
    turn, followed by the projection's first input plan.
    """
    query = "select c1, 123, c1 < 123 from test"
    projection = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(projection, Projection)

    projected = projection.projections()

    # First projected expression: the plain column reference.
    column = projected[0].to_variant()
    assert isinstance(column, Column)
    assert column.name() == "c1"
    assert column.qualified_name() == "test.c1"

    # Second projected expression: the integer literal.
    literal = projected[1].to_variant()
    assert isinstance(literal, Literal)
    assert literal.data_type() == "Int64"
    assert literal.value_i64() == 123

    # Third projected expression: the ``<`` comparison of column and literal.
    comparison = projected[2].to_variant()
    assert isinstance(comparison, BinaryExpr)
    assert isinstance(comparison.left().to_variant(), Column)
    assert comparison.op() == "<"
    assert isinstance(comparison.right().to_variant(), Literal)

    # The first input of the projection is expected to be the table scan.
    scan = projection.input()[0].to_variant()
    assert isinstance(scan, TableScan)
def test_filter(test_ctx):
    """Check that a WHERE clause yields a ``Filter`` as first input of the ``Projection``."""
    df = test_ctx.sql("select c1 from test WHERE c1 > 5")

    root = df.logical_plan().to_variant()
    assert isinstance(root, Projection)

    child = root.input()[0].to_variant()
    assert isinstance(child, Filter)
def test_limit(test_ctx):
    """Check that LIMIT queries produce a ``Limit`` plan reporting the expected skip."""
    # (query, expected value of Limit.skip()) — evaluated in order.
    cases = (
        ("select c1 from test LIMIT 10", 0),
        ("select c1 from test LIMIT 10 OFFSET 5", 5),
    )
    for query, expected_skip in cases:
        node = test_ctx.sql(query).logical_plan().to_variant()
        assert isinstance(node, Limit)
        assert node.skip() == expected_skip
def test_aggregate_query(test_ctx):
    """Check the group-by and aggregate expressions exposed by an ``Aggregate`` plan."""
    df = test_ctx.sql("select c1, count(*) from test group by c1")

    top = df.logical_plan().to_variant()
    assert isinstance(top, Projection)

    agg_node = top.input()[0].to_variant()
    assert isinstance(agg_node, Aggregate)

    # The single grouping key is the column ``c1`` of table ``test``.
    group_key = agg_node.group_by_exprs()[0].to_variant()
    assert isinstance(group_key, Column)
    assert group_key.name() == "c1"
    assert group_key.qualified_name() == "test.c1"

    # The first aggregate expression comes from ``count(*)``.
    agg_expr = agg_node.aggregate_exprs()[0].to_variant()
    assert isinstance(agg_expr, AggregateFunction)
def test_sort(test_ctx):
    """Check that an ORDER BY query has a ``Sort`` node at the top of its logical plan."""
    ordered = test_ctx.sql("select c1 from test order by c1")
    top = ordered.logical_plan().to_variant()
    assert isinstance(top, Sort)
def test_relational_expr(test_ctx):
    """Check the relational operators of ``col`` expressions on an in-memory batch.

    A three-row record batch with an integer column ``a`` and a string column
    ``b`` is turned into a DataFrame; each comparison operator is applied via
    ``filter`` and the resulting row count is asserted.
    """
    batch = pyarrow.RecordBatch.from_arrays(
        [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
        names=["a", "b"],
    )
    # Use the injected fixture context instead of constructing a second,
    # otherwise unused SessionContext alongside it.
    df = test_ctx.create_dataframe([[batch]], name="batch_array")

    # Integer column comparisons.
    assert df.filter(col("a") == 1).count() == 1
    assert df.filter(col("a") != 1).count() == 2
    assert df.filter(col("a") >= 1).count() == 3
    assert df.filter(col("a") > 1).count() == 2
    assert df.filter(col("a") <= 3).count() == 3
    assert df.filter(col("a") < 3).count() == 2

    # String column comparisons.
    assert df.filter(col("b") == "beta").count() == 1
    assert df.filter(col("b") != "beta").count() == 2

    # Comparing the integer column with a string is expected to match no rows.
    assert df.filter(col("a") == "beta").count() == 0
def test_expr_to_variant():
    """Check that the predicate of an ``IN (...)`` filter can be inspected as a variant.

    Taken from https://github.com/apache/datafusion-python/issues/781

    The optimized logical plan is searched for a ``Filter`` node; the variant
    of its predicate must expose the tested column, the list of candidate
    values and the negation flag.
    """

    def traverse_logical_plan(plan):
        """Return the predicate variant of the first ``Filter`` found, else ``None``.

        The plan is searched depth-first through ``plan.inputs()``.
        """
        cur_node = plan.to_variant()
        if isinstance(cur_node, Filter):
            return cur_node.predicate().to_variant()
        if hasattr(plan, "inputs"):
            for input_plan in plan.inputs():
                res = traverse_logical_plan(input_plan)
                if res is not None:
                    return res
        # No Filter in this subtree; the caller checks for None explicitly.
        return None

    ctx = SessionContext()
    data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]}
    ctx.from_pydict(data, name="table1")
    query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
    logical_plan = ctx.sql(query).optimized_logical_plan()
    variant = traverse_logical_plan(logical_plan)
    assert variant is not None
    assert variant.expr().to_variant().qualified_name() == "table1.name"
    assert (
        str(variant.list())
        == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
    )
    assert not variant.negated()