# blob: 143eea6ff00d63da7bc4b95a4f3190796d9cb0e6 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datafusion import SessionContext
from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction
from datafusion.expr import (
Projection,
Filter,
Aggregate,
Limit,
Sort,
TableScan,
)
import pytest
@pytest.fixture
def test_ctx():
    """Return a fresh ``SessionContext`` with the sample CSV registered as ``test``."""
    session = SessionContext()
    session.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
    return session
def test_projection(test_ctx):
    """Inspect the projection node and each of its three projected expressions."""
    frame = test_ctx.sql("select c1, 123, c1 < 123 from test")
    projection = frame.logical_plan().to_variant()
    assert isinstance(projection, Projection)

    projected = projection.projections()

    # First expression: the plain column reference.
    column = projected[0].to_variant()
    assert isinstance(column, Column)
    assert column.name() == "c1"
    assert column.qualified_name() == "test.c1"

    # Second expression: the integer literal.
    literal = projected[1].to_variant()
    assert isinstance(literal, Literal)
    assert literal.data_type() == "Int64"
    assert literal.value_i64() == 123

    # Third expression: the comparison between the column and the literal.
    comparison = projected[2].to_variant()
    assert isinstance(comparison, BinaryExpr)
    lhs = comparison.left().to_variant()
    assert isinstance(lhs, Column)
    assert comparison.op() == "<"
    rhs = comparison.right().to_variant()
    assert isinstance(rhs, Literal)

    # The projection's input is expected to be the table scan.
    scan = projection.input().to_variant()
    assert isinstance(scan, TableScan)
def test_filter(test_ctx):
    """A WHERE clause yields a ``Filter`` node directly beneath the projection."""
    frame = test_ctx.sql("select c1 from test WHERE c1 > 5")
    projection = frame.logical_plan().to_variant()
    assert isinstance(projection, Projection)

    filter_node = projection.input().to_variant()
    assert isinstance(filter_node, Filter)
def test_limit(test_ctx):
    """A LIMIT clause makes ``Limit`` the root of the logical plan."""
    frame = test_ctx.sql("select c1 from test LIMIT 10")
    root = frame.logical_plan().to_variant()
    assert isinstance(root, Limit)
def test_aggregate_query(test_ctx):
    """A GROUP BY query plans as a projection over an ``Aggregate`` node."""
    frame = test_ctx.sql("select c1, count(*) from test group by c1")
    root = frame.logical_plan().to_variant()
    assert isinstance(root, Projection)

    agg_node = root.input().to_variant()
    assert isinstance(agg_node, Aggregate)

    # The first grouping expression is the column named in GROUP BY.
    group_col = agg_node.group_by_exprs()[0].to_variant()
    assert isinstance(group_col, Column)
    assert group_col.name() == "c1"
    assert group_col.qualified_name() == "test.c1"

    # The first aggregate expression is the count(*) call.
    agg_expr = agg_node.aggregate_exprs()[0].to_variant()
    assert isinstance(agg_expr, AggregateFunction)
def test_sort(test_ctx):
    """An ORDER BY clause makes ``Sort`` the root of the logical plan."""
    frame = test_ctx.sql("select c1 from test order by c1")
    root = frame.logical_plan().to_variant()
    assert isinstance(root, Sort)