# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import pyarrow as pa
import pytest
from datafusion import SessionContext, col
from datafusion.expr import (
    Aggregate,
    AggregateFunction,
    BinaryExpr,
    Column,
    Filter,
    Limit,
    Literal,
    Projection,
    Sort,
    TableScan,
)


@pytest.fixture
def test_ctx():
    """Return a SessionContext with the aggregate CSV registered as ``test``."""
    session = SessionContext()
    session.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
    return session


def test_projection(test_ctx):
    """Inspect the Projection node and each of its three expressions."""
    query = "select c1, 123, c1 < 123 from test"
    projection = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(projection, Projection)

    exprs = projection.projections()

    # First projected expression: the bare column reference.
    column = exprs[0].to_variant()
    assert isinstance(column, Column)
    assert column.name() == "c1"
    assert column.qualified_name() == "test.c1"

    # Second projected expression: the integer literal.
    literal = exprs[1].to_variant()
    assert isinstance(literal, Literal)
    assert literal.data_type() == "Int64"
    assert literal.value_i64() == 123

    # Third projected expression: the comparison of the column and literal.
    comparison = exprs[2].to_variant()
    assert isinstance(comparison, BinaryExpr)
    lhs = comparison.left().to_variant()
    assert isinstance(lhs, Column)
    assert comparison.op() == "<"
    rhs = comparison.right().to_variant()
    assert isinstance(rhs, Literal)

    # The projection's first input is expected to be the table scan.
    scan = projection.input()[0].to_variant()
    assert isinstance(scan, TableScan)


def test_filter(test_ctx):
    """A WHERE clause yields a Filter node directly beneath the Projection."""
    query = "select c1 from test WHERE c1 > 5"
    projection = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(projection, Projection)

    child = projection.input()[0].to_variant()
    assert isinstance(child, Filter)


def test_limit(test_ctx):
    """LIMIT queries produce a Limit node whose text shows the skip value."""
    # Each case pairs a query with the skip text expected in str(plan).
    cases = (
        ("select c1 from test LIMIT 10", "Skip: None"),
        ("select c1 from test LIMIT 10 OFFSET 5", "Skip: Some(Literal(Int64(5)))"),
    )

    for sql, expected_skip in cases:
        node = test_ctx.sql(sql).logical_plan().to_variant()
        assert isinstance(node, Limit)
        assert expected_skip in str(node)


def test_aggregate_query(test_ctx):
    """GROUP BY yields an Aggregate node exposing group and aggregate exprs."""
    query = "select c1, count(*) from test group by c1"
    top = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(top, Projection)

    agg_node = top.input()[0].to_variant()
    assert isinstance(agg_node, Aggregate)

    # The first grouping expression is the column named in GROUP BY.
    group_col = agg_node.group_by_exprs()[0].to_variant()
    assert isinstance(group_col, Column)
    assert group_col.name() == "c1"
    assert group_col.qualified_name() == "test.c1"

    # The first aggregate expression is the count(*) call.
    agg_expr = agg_node.aggregate_exprs()[0].to_variant()
    assert isinstance(agg_expr, AggregateFunction)


def test_sort(test_ctx):
    """An ORDER BY query has a Sort node at the top of its logical plan."""
    query = "select c1 from test order by c1"
    node = test_ctx.sql(query).logical_plan().to_variant()
    assert isinstance(node, Sort)


def test_relational_expr(test_ctx):
    """Exercise comparison operators on expressions via DataFrame.filter.

    The ``test_ctx`` argument is not referenced in the body; a separate
    context holding a small in-memory batch is created here instead.
    """
    session = SessionContext()

    ints = pa.array([1, 2, 3])
    words = pa.array(["alpha", "beta", "gamma"], type=pa.string_view())
    record_batch = pa.RecordBatch.from_arrays([ints, words], names=["a", "b"])
    frame = session.create_dataframe([[record_batch]], name="batch_array")

    def matches(predicate):
        # Count of rows in the in-memory frame that satisfy the predicate.
        return frame.filter(predicate).count()

    # Integer column against an integer literal.
    assert matches(col("a") == 1) == 1
    assert matches(col("a") != 1) == 2
    assert matches(col("a") >= 1) == 3
    assert matches(col("a") > 1) == 2
    assert matches(col("a") <= 3) == 3
    assert matches(col("a") < 3) == 2

    # String-view column against a string literal.
    assert matches(col("b") == "beta") == 1
    assert matches(col("b") != "beta") == 2

    # Integer column against a string literal matches no rows.
    assert matches(col("a") == "beta") == 0


def test_expr_to_variant():
    """Find the Filter predicate of an IN-list query and inspect its variant."""
    # Taken from https://github.com/apache/datafusion-python/issues/781
    from datafusion import SessionContext
    from datafusion.expr import Filter

    def find_filter_predicate(node):
        # Depth-first search for the first Filter; returns its predicate
        # variant, or None when no Filter is found below ``node``.
        variant = node.to_variant()
        if isinstance(variant, Filter):
            return variant.predicate().to_variant()
        if not hasattr(node, "inputs"):
            return None
        for child in node.inputs():
            found = find_filter_predicate(child)
            if found is not None:
                return found
        return None

    session = SessionContext()
    session.from_pydict(
        {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]},
        name="table1",
    )
    sql = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
    optimized = session.sql(sql).optimized_logical_plan()

    in_list = find_filter_predicate(optimized)
    assert in_list is not None
    assert in_list.expr().to_variant().qualified_name() == "table1.name"
    assert (
        str(in_list.list())
        == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'  # noqa: E501
    )
    assert not in_list.negated()


def test_expr_getitem() -> None:
    """Index a struct column by field name and an array column by position."""
    session = SessionContext()
    people = [
        {"name": "Alice", "age": 15},
        {"name": "Bob", "age": 14},
        {"name": "Charlie", "age": 13},
        {"name": None, "age": 12},
    ]
    frame = session.from_pydict(
        {
            "array_values": [[1, 2, 3], [4, 5], [6], []],
            "struct_values": people,
        },
        name="table1",
    )

    # Struct field lookup by key.
    name_batches = frame.select(col("struct_values")["name"].alias("name")).collect()
    names = [cell.as_py() for batch in name_batches for cell in batch["name"]]

    # Array element lookup by integer index.
    value_batches = frame.select(col("array_values")[1].alias("value")).collect()
    values = [cell.as_py() for batch in value_batches for cell in batch["value"]]

    assert names == ["Alice", "Bob", "Charlie", None]
    assert values == [2, 5, None, None]


def test_display_name_deprecation():
    """``display_name`` emits one DeprecationWarning and matches schema_name."""
    import warnings

    column = col("foo")
    with warnings.catch_warnings(record=True) as caught:
        # Make sure no warning is suppressed by an existing filter.
        warnings.simplefilter("always")

        # The call under test; expected to warn.
        shown = column.display_name()

        # Exactly one warning, of the deprecation category, was recorded.
        assert len(caught) == 1
        latest = caught[-1]
        assert issubclass(latest.category, DeprecationWarning)
        assert "deprecated" in str(latest.message)

    # The deprecated call still returns the same value as its replacement.
    assert shown == column.schema_name()
    assert shown == "foo"


@pytest.fixture
def df():
    """Return a three-row DataFrame whose columns a, b and c contain nulls."""
    arrays = [
        pa.array([1, 2, None]),
        pa.array([4, None, 6]),
        pa.array([None, None, 8]),
    ]
    record_batch = pa.RecordBatch.from_arrays(arrays, names=["a", "b", "c"])

    # Wrap the single record batch in a DataFrame on a fresh context.
    return SessionContext().from_arrow(record_batch)


def test_fill_null(df):
    """``fill_null`` replaces nulls in each column with the given value."""
    # Column name -> replacement value used for its nulls.
    replacements = {"a": 100, "b": 25, "c": 1234}
    filled = df.select(
        *[
            col(name).fill_null(value).alias(name)
            for name, value in replacements.items()
        ]
    )
    filled.show()
    batch = filled.collect()[0]

    assert batch.column(0) == pa.array([1, 2, 100])
    assert batch.column(1) == pa.array([4, 25, 6])
    assert batch.column(2) == pa.array([1234, 1234, 8])
