# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import ctypes
import datetime
import itertools
import os
import re
import threading
import time
from pathlib import Path
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from datafusion import (
    DataFrame,
    InsertOp,
    ParquetColumnOptions,
    ParquetWriterOptions,
    RecordBatch,
    SessionContext,
    WindowFrame,
    column,
    literal,
)
from datafusion import (
    col as df_col,
)
from datafusion import (
    functions as f,
)
from datafusion.dataframe import DataFrameWriteOptions
from datafusion.dataframe_formatter import (
    DataFrameHtmlFormatter,
    configure_formatter,
    get_formatter,
    reset_formatter,
)
from datafusion.expr import EXPR_TYPE_ERROR, Window
from pyarrow.csv import write_csv

# Skip this entire test module when pyarrow's cffi bindings cannot be imported.
pa_cffi = pytest.importorskip("pyarrow.cffi")

# One mebibyte, in bytes.
MB = 1024 * 1024


@pytest.fixture
def ctx():
    """Provide a fresh, default-configured ``SessionContext`` for each test."""
    session = SessionContext()
    return session


@pytest.fixture
def df(ctx):
    """Three-row DataFrame with a=[1, 2, 3], b=[4, 5, 6] and c=[8, 5, 8].

    Built from a single Arrow ``RecordBatch`` via the ``ctx`` fixture.
    """
    columns = {
        "a": pa.array([1, 2, 3]),
        "b": pa.array([4, 5, 6]),
        "c": pa.array([8, 5, 8]),
    }
    batch = pa.RecordBatch.from_arrays(list(columns.values()), names=list(columns))
    return ctx.from_arrow(batch)


@pytest.fixture
def large_df():
    """DataFrame with 100,000 rows spanning integer, string and float columns.

    Row ``i`` holds ``a=i``, ``b="s-{i}"`` and ``c=i + 0.1``.
    """
    ctx = SessionContext()
    n_rows = 100_000
    indices = range(n_rows)
    batch = pa.record_batch(
        {
            "a": list(indices),
            "b": [f"s-{i}" for i in indices],
            # ``i + 0.1`` is already a float, so no explicit conversion is needed.
            "c": [i + 0.1 for i in indices],
        }
    )
    return ctx.from_arrow(batch)


@pytest.fixture
def struct_df():
    """DataFrame with struct column ``a`` (field ``c``) and int column ``b``."""
    ctx = SessionContext()
    structs = pa.array([{"c": value} for value in (1, 2, 3)])
    ints = pa.array([4, 5, 6])
    batch = pa.RecordBatch.from_arrays([structs, ints], names=["a", "b"])
    return ctx.create_dataframe([[batch]])


@pytest.fixture
def nested_df():
    """DataFrame with list column ``a`` and int column ``b`` for unnest tests.

    The lists deliberately have different lengths (1, 2 and 3 elements) and
    the final row holds a null list, so unnesting exercises both fan-out and
    null handling.
    """
    ctx = SessionContext()
    lists = pa.array([[1], [2, 3], [4, 5, 6], None])
    scalars = pa.array([7, 8, 9, 10])
    batch = pa.RecordBatch.from_arrays([lists, scalars], names=["a", "b"])
    return ctx.create_dataframe([[batch]])


@pytest.fixture
def aggregate_df():
    """DataFrame of ``sum(c2)`` grouped by ``c1`` over the aggregate_test_100 CSV.

    The CSV path is relative, so it is resolved against the current working
    directory.
    """
    ctx = SessionContext()
    ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
    query = "select c1, sum(c2) from test group by c1"
    return ctx.sql(query)


@pytest.fixture
def partitioned_df():
    """Seven-row DataFrame used by the window-function tests.

    Column ``a`` is a unique row id 0..6, ``b`` holds ints with nulls and ties,
    and ``c`` splits the rows into group "A" (rows 0-3) and group "B" (rows 4-6).
    """
    ctx = SessionContext()
    ids = pa.array(list(range(7)))
    values = pa.array([7, None, 7, 8, 9, None, 9])
    groups = pa.array(["A"] * 4 + ["B"] * 3)
    batch = pa.RecordBatch.from_arrays([ids, values, groups], names=["a", "b", "c"])
    return ctx.create_dataframe([[batch]])


@pytest.fixture
def clean_formatter_state():
    """Reset the HTML formatter before and after each test that requests it.

    Resetting before the test isolates it from configuration left behind by
    earlier tests. The code after ``yield`` runs as pytest teardown, even when
    the test fails, so this test's ``configure_formatter`` changes cannot leak
    into tests that run later.
    """
    reset_formatter()
    yield
    reset_formatter()


@pytest.fixture
def null_df():
    """Four-row DataFrame with nulls scattered across columns of different types.

    There is one column per type (int64, float64, string, bool, date32, date64),
    and the last row is null in every column.
    """
    ctx = SessionContext()
    # date32 counts days since the Unix epoch and date64 counts milliseconds;
    # both columns encode 2000-01-01, null, 2022-01-01, null.
    columns = {
        "int_col": pa.array([1, None, 3, None], type=pa.int64()),
        "float_col": pa.array([4.5, 6.7, None, None], type=pa.float64()),
        "str_col": pa.array(["a", None, "c", None], type=pa.string()),
        "bool_col": pa.array([True, None, False, None], type=pa.bool_()),
        "date32_col": pa.array([10957, None, 18993, None], type=pa.date32()),
        "date64_col": pa.array(
            [946684800000, None, 1640995200000, None], type=pa.date64()
        ),
    }
    batch = pa.RecordBatch.from_arrays(list(columns.values()), names=list(columns))
    return ctx.create_dataframe([[batch]])


# Style provider returning fixed inline CSS, used by the HTML formatter tests.
class CustomStyleProvider:
    """Supply constant inline CSS strings for table cells and headers."""

    _CELL_STYLE = (
        "background-color: #f5f5f5; color: #333; padding: 8px; border: "
        "1px solid #ddd;"
    )
    _HEADER_STYLE = (
        "background-color: #4285f4; color: white; font-weight: bold; "
        "padding: 10px; border: 1px solid #3367d6;"
    )

    def get_cell_style(self) -> str:
        """Return the inline CSS for table cells."""
        return self._CELL_STYLE

    def get_header_style(self) -> str:
        """Return the inline CSS for header cells."""
        return self._HEADER_STYLE


def count_table_rows(html_content: str) -> int:
    """Count the number of table rows in HTML content.

    A row is any ``<tr>`` opening tag, with or without attributes. The word
    boundary after ``tr`` stops other tags that merely start with ``<tr``
    (for example ``<track>``) from being counted as rows.

    Args:
        html_content: HTML string to analyze.

    Returns:
        Number of ``<tr`` opening tags found.
    """
    return len(re.findall(r"<tr\b", html_content))


def test_select(df):
    """select() accepts both expressions and plain column-name strings."""
    # Expression projections: element-wise sum and difference of a and b.
    arithmetic = df.select(
        column("a") + column("b"),
        column("a") - column("b"),
    )
    first_batch = arithmetic.collect()[0]
    assert first_batch.column(0) == pa.array([5, 7, 9])
    assert first_batch.column(1) == pa.array([-3, -3, -3])

    # Column names, listed in a different order than the source schema.
    reordered = df.select("b", "a")
    first_batch = reordered.collect()[0]
    assert first_batch.column(0) == pa.array([4, 5, 6])
    assert first_batch.column(1) == pa.array([1, 2, 3])


def test_select_exprs(df):
    """select_exprs() parses SQL expression strings, including function calls."""
    cases = [
        (("a + b", "a - b"), ([5, 7, 9], [-3, -3, -3])),
        (("b", "a"), ([4, 5, 6], [1, 2, 3])),
        (("abs(a + b)", "abs(a - b)"), ([5, 7, 9], [3, 3, 3])),
    ]
    for exprs, (first_expected, second_expected) in cases:
        # Each result fits in a single batch.
        batch = df.select_exprs(*exprs).collect()[0]
        assert batch.column(0) == pa.array(first_expected)
        assert batch.column(1) == pa.array(second_expected)


def test_drop_quoted_columns():
    """drop() resolves a mixed-case column name whether or not it is quoted."""
    ctx = SessionContext()
    only_column = pa.array([1, 2, 3])
    batch = pa.RecordBatch.from_arrays([only_column], names=["ID_For_Students"])
    frame = ctx.create_dataframe([[batch]])

    for name in ('"ID_For_Students"', "ID_For_Students"):
        assert frame.drop(name).schema().names == []


def test_select_mixed_expr_string(df):
    """An expression and a column-name string can be mixed in one select()."""
    projected = df.select(column("b"), "a")
    batch = projected.collect()[0]
    assert batch.column(0) == pa.array([4, 5, 6])
    assert batch.column(1) == pa.array([1, 2, 3])


def test_select_unsupported(df):
    """select() raises TypeError for arguments that are neither Expr nor str."""
    expected = f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}"
    with pytest.raises(TypeError, match=expected):
        df.select(1)


def test_filter(df):
    """filter() with one predicate, with no predicates, and with several."""
    projected = df.filter(column("a") > literal(2)).select(
        column("a") + column("b"),
        column("a") - column("b"),
    )
    batch = projected.collect()[0]
    assert batch.column(0) == pa.array([9])
    assert batch.column(1) == pa.array([-3])

    df.show()
    # With no predicates, filter() must leave the wrapped DataFrame untouched.
    unfiltered = df.filter()
    assert df.df == unfiltered.df

    # Multiple predicates must all hold, leaving only the row (2, 5, 5).
    combined = df.filter(column("a") > literal(1), column("b") != literal(6))
    batch = combined.collect()[0]
    for index, value in enumerate([2, 5, 5]):
        assert batch.column(index) == pa.array([value])


def test_filter_string_predicates(df):
    """filter() accepts SQL strings alone, mixed with Exprs, or several at once."""
    single = df.filter("a > 2").collect()[0]
    assert single.column(0) == pa.array([3])
    assert single.column(1) == pa.array([6])
    assert single.column(2) == pa.array([8])

    # Both remaining variants must select the same row: (a=2, b=5, c=5).
    predicate_sets = (
        ("a > 1", column("b") != literal(6)),
        ("a > 1", "b < 6"),
    )
    for predicates in predicate_sets:
        batch = df.filter(*predicates).collect()[0]
        assert batch.column(0) == pa.array([2])
        assert batch.column(1) == pa.array([5])
        assert batch.column(2) == pa.array([5])


def test_parse_sql_expr(df):
    """parse_sql_expr() yields Exprs that behave like hand-built ones."""
    parsed_plan = df.filter(df.parse_sql_expr("a > 2")).logical_plan()
    built_plan = df.filter(column("a") > literal(2)).logical_plan()
    # Plans do not implement equality, so compare their string renderings.
    assert str(parsed_plan) == str(built_plan)

    projected = df.filter(df.parse_sql_expr("a > 2")).select(
        column("a") + column("b"),
        column("a") - column("b"),
    )
    batch = projected.collect()[0]
    assert batch.column(0) == pa.array([9])
    assert batch.column(1) == pa.array([-3])

    df.show()
    # With no predicates, filter() must leave the wrapped DataFrame untouched.
    unfiltered = df.filter()
    assert df.df == unfiltered.df

    combined = df.filter(df.parse_sql_expr("a > 1"), df.parse_sql_expr("b != 6"))
    batch = combined.collect()[0]
    for index, value in enumerate([2, 5, 5]):
        assert batch.column(index) == pa.array([value])


def test_show_empty(df, capsys):
    """show() on an empty result prints a notice that there are no rows."""
    no_rows = df.filter(column("a") > literal(3))
    no_rows.show()
    printed = capsys.readouterr().out
    assert "DataFrame has no rows" in printed


def test_sort(df):
    """Sorting on b in descending order reverses the fixture's row order."""
    by_b_desc = df.sort(column("b").sort(ascending=False))
    table = pa.Table.from_batches(by_b_desc.collect())
    assert table.to_pydict() == {
        "a": [3, 2, 1],
        "b": [6, 5, 4],
        "c": [8, 5, 8],
    }


def test_sort_string_and_expression_equivalent(df):
    """sort("a") and sort(col("a")) produce identical results."""
    # df_col is the module-level alias of datafusion.col.
    by_name = df.sort("a").to_pydict()
    by_expr = df.sort(df_col("a")).to_pydict()
    assert by_name == by_expr


def test_sort_unsupported(df):
    """sort() raises TypeError for arguments that are neither Expr nor str."""
    expected = f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}"
    with pytest.raises(TypeError, match=expected):
        df.sort(1)


def test_aggregate_string_and_expression_equivalent(df):
    """A group-by key can be given as a column name or as col()."""
    # df_col is the module-level alias of datafusion.col.
    by_name, by_expr = [
        df.aggregate(key, [f.count()]).sort("a").to_pydict()
        for key in ("a", df_col("a"))
    ]
    assert by_name == by_expr


def test_aggregate_tuple_group_by(df):
    """group_by can be passed as a tuple as well as a list."""
    from_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict()
    from_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict()
    assert from_tuple == from_list


def test_aggregate_tuple_aggs(df):
    """Aggregate expressions can be passed as a tuple as well as a list."""
    from_list = df.aggregate("a", [f.count()]).sort("a").to_pydict()
    from_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict()
    assert from_tuple == from_list


def test_filter_string_equivalent(df):
    """A SQL string predicate filters exactly like the equivalent Expr."""
    via_string = df.filter("a > 1").to_pydict()
    via_expr = df.filter(column("a") > literal(1)).to_pydict()
    assert via_string == via_expr


def test_filter_string_invalid(df):
    """Invalid SQL raises an error other than the Expr type-check message."""
    with pytest.raises(Exception) as excinfo:
        df.filter("this is not valid sql").collect()
    message = str(excinfo.value)
    assert "Expected Expr" not in message


def test_drop(df):
    """drop() removes the named column from both the schema and the data."""
    without_c = df.drop("c")
    batch = without_c.collect()[0]

    assert without_c.schema().names == ["a", "b"]
    assert batch.column(0) == pa.array([1, 2, 3])
    assert batch.column(1) == pa.array([4, 5, 6])


def test_limit(df):
    """limit(1) keeps a single row."""
    batch = df.limit(1).collect()[0]
    assert len(batch.column(0)) == 1
    assert len(batch.column(1)) == 1


def test_limit_with_offset(df):
    """limit() with an offset skips the leading rows before applying the limit."""
    # Only 3 rows exist: skip 2 and request 5, so the limit runs past the end.
    df = df.limit(5, offset=2)

    # execute and collect the first (and only) batch
    result = df.collect()[0]

    assert len(result.column(0)) == 1
    assert len(result.column(1)) == 1
    # Pin the surviving row, so an offset that skipped the wrong rows fails here.
    assert result.column(0) == pa.array([3])
    assert result.column(1) == pa.array([6])
    assert result.column(2) == pa.array([8])


def test_head(df):
    """head(1) returns only the first row."""
    batch = df.head(1).collect()[0]
    for index, value in enumerate((1, 4, 8)):
        assert batch.column(index) == pa.array([value])


def test_tail(df):
    """tail(1) returns only the last row."""
    batch = df.tail(1).collect()[0]
    for index, value in enumerate((3, 6, 8)):
        assert batch.column(index) == pa.array([value])


def test_with_column_sql_expression(df):
    """with_column() accepts a SQL string and replaces an existing column."""
    replaced = df.with_column("c", "a + b")
    batch = replaced.collect()[0]

    assert [batch.schema.field(i).name for i in range(3)] == ["a", "b", "c"]
    assert batch.column(0) == pa.array([1, 2, 3])
    assert batch.column(1) == pa.array([4, 5, 6])
    assert batch.column(2) == pa.array([5, 7, 9])


def test_with_column(df):
    """with_column() with an Expr replaces an existing column in place."""
    replaced = df.with_column("c", column("a") + column("b"))
    batch = replaced.collect()[0]

    assert [batch.schema.field(i).name for i in range(3)] == ["a", "b", "c"]
    assert batch.column(0) == pa.array([1, 2, 3])
    assert batch.column(1) == pa.array([4, 5, 6])
    assert batch.column(2) == pa.array([5, 7, 9])


def test_with_columns(df):
    """with_columns() accepts positional Exprs, lists of Exprs and keyword Exprs."""
    total = column("a") + column("b")
    extended = df.with_columns(
        total.alias("c"),
        total.alias("d"),
        [total.alias("e"), total.alias("f")],
        g=total,
    )
    batch = extended.collect()[0]

    assert [batch.schema.field(i).name for i in range(7)] == list("abcdefg")
    assert batch.column(0) == pa.array([1, 2, 3])
    assert batch.column(1) == pa.array([4, 5, 6])
    # Every added (or replaced) column holds a + b.
    for index in range(2, 7):
        assert batch.column(index) == pa.array([5, 7, 9])


def test_with_columns_str(df):
    """with_columns() also accepts SQL strings, including "expr as alias" forms."""
    extended = df.with_columns(
        "a + b as c",
        "a + b as d",
        ["a + b as e", "a + b as f"],
        g="a + b",
    )
    batch = extended.collect()[0]

    assert [batch.schema.field(i).name for i in range(7)] == list("abcdefg")
    assert batch.column(0) == pa.array([1, 2, 3])
    assert batch.column(1) == pa.array([4, 5, 6])
    # Every added (or replaced) column holds a + b.
    for index in range(2, 7):
        assert batch.column(index) == pa.array([5, 7, 9])


def test_cast(df):
    """cast() changes only the listed columns; the others keep their types."""
    casted = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
    assert casted.schema() == pa.schema(
        [
            ("a", pa.float16()),
            ("b", pa.list_(pa.uint32())),
            ("c", pa.int64()),
        ]
    )


def test_iter_batches(df):
    """Iterating a DataFrame yields RecordBatches usable after `del df`."""
    collected = []
    for item in df:
        collected.append(item)  # noqa: PERF402

    # Drop the local DataFrame reference before inspecting the batches.
    del df

    assert len(collected) == 1
    only = collected[0]
    assert isinstance(only, RecordBatch)
    arrow_batch = only.to_pyarrow()
    for index, expected in enumerate(([1, 2, 3], [4, 5, 6], [8, 5, 8])):
        assert arrow_batch.column(index).to_pylist() == expected


def test_iter_returns_datafusion_recordbatch(df):
    """Iteration yields datafusion ``RecordBatch`` wrapper objects."""
    assert all(isinstance(item, RecordBatch) for item in df)


def test_execute_stream_basic(df):
    """execute_stream() yields the fixture's data as one RecordBatch."""
    batches = list(df.execute_stream())

    assert len(batches) == 1
    assert isinstance(batches[0], RecordBatch)
    arrow_batch = batches[0].to_pyarrow()
    expected_columns = ([1, 2, 3], [4, 5, 6], [8, 5, 8])
    for index, expected in enumerate(expected_columns):
        assert arrow_batch.column(index).to_pylist() == expected


def test_with_column_renamed(df):
    """with_column_renamed() renames a column while keeping its position."""
    with_sum = df.with_column("c", column("a") + column("b"))
    renamed = with_sum.with_column_renamed("c", "sum")

    schema = renamed.collect()[0].schema
    assert [schema.field(i).name for i in range(3)] == ["a", "b", "sum"]


def test_unnest(nested_df):
    """Unnesting fans each list out into rows; a null list yields one null row."""
    batch = nested_df.unnest_columns("a").collect()[0]
    assert batch.column(0) == pa.array([1, 2, 3, 4, 5, 6, None])
    assert batch.column(1) == pa.array([7, 8, 8, 9, 9, 9, 10])


def test_unnest_without_nulls(nested_df):
    """With preserve_nulls=False, the row holding a null list is dropped."""
    unnested = nested_df.unnest_columns("a", preserve_nulls=False)
    batch = unnested.collect()[0]
    assert batch.column(0) == pa.array([1, 2, 3, 4, 5, 6])
    assert batch.column(1) == pa.array([7, 8, 8, 9, 9, 9])


def test_join():
    """join() on a shared key, with and without dropping the duplicate key."""
    ctx = SessionContext()
    left_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    right_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([8, 10])],
        names=["a", "c"],
    )
    left = ctx.create_dataframe([[left_batch]], "l")
    right = ctx.create_dataframe([[right_batch]], "r")
    expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}

    joined = left.join(right, on="a", how="inner").sort(column("l.a"))
    assert pa.Table.from_batches(joined.collect()).to_pydict() == expected

    # to_pydict() would collapse duplicate column names, so the checks below
    # inspect the raw arrays to see exactly which key columns survived.
    deduped = left.join(
        right, left_on="a", right_on="a", how="inner", drop_duplicate_keys=True
    ).sort(column("l.a"))
    batch = deduped.collect()[0]
    assert batch.num_columns == 3
    for index, values in enumerate(([1, 2], [4, 5], [8, 10])):
        assert batch.column(index) == pa.array(values, pa.int64())

    kept = left.join(
        right, left_on="a", right_on="a", how="inner", drop_duplicate_keys=False
    ).sort(column("l.a"))
    batch = kept.collect()[0]
    assert batch.num_columns == 4
    for index, values in enumerate(([1, 2], [4, 5], [1, 2], [8, 10])):
        assert batch.column(index) == pa.array(values, pa.int64())

    # Pre-43.0.0 callers passed join_keys positionally; that must keep working.
    legacy = left.join(right, (["a"], ["a"]), how="inner").sort(column("l.a"))
    assert pa.Table.from_batches(legacy.collect()).to_pydict() == expected


def test_join_invalid_params():
    """join() warns on the deprecated join_keys and rejects bad key combinations."""
    ctx = SessionContext()
    left_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    right_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([8, 10])],
        names=["a", "c"],
    )
    left = ctx.create_dataframe([[left_batch]], "l")
    right = ctx.create_dataframe([[right_batch]], "r")

    with pytest.deprecated_call():
        joined = left.join(right, join_keys=(["a"], ["a"]), how="inner")
        joined.show()
        joined = joined.sort(column("l.a"))
        table = pa.Table.from_batches(joined.collect())
        assert table.to_pydict() == {"a": [1, 2], "c": [8, 10], "b": [4, 5]}

    with pytest.raises(
        ValueError, match=r"`left_on` or `right_on` should not provided with `on`"
    ):
        left.join(right, on="a", how="inner", right_on="test")

    with pytest.raises(
        ValueError, match=r"`left_on` and `right_on` should both be provided."
    ):
        left.join(right, left_on="a", how="inner")

    with pytest.raises(
        ValueError, match=r"either `on` or `left_on` and `right_on` should be provided."
    ):
        left.join(right, how="inner")


def test_join_on():
    """join_on() with an equality predicate, alone and with an extra predicate."""
    ctx = SessionContext()
    left_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    right_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([-8, 10])],
        names=["a", "c"],
    )
    left = ctx.create_dataframe([[left_batch]], "l")
    right = ctx.create_dataframe([[right_batch]], "r")
    keys_match = column("l.a") == column("r.a")

    equi = left.join_on(right, keys_match, how="inner")
    equi.show()
    table = pa.Table.from_batches(equi.sort(column("l.a")).collect())
    assert table.to_pydict() == {"a": [1, 2], "c": [-8, 10], "b": [4, 5]}

    # The extra l.a < r.c predicate rejects the row whose c is -8.
    filtered = left.join_on(
        right,
        keys_match,
        column("l.a") < column("r.c"),
        how="inner",
    )
    filtered.show()
    table = pa.Table.from_batches(filtered.sort(column("l.a")).collect())
    assert table.to_pydict() == {"a": [2], "c": [10], "b": [5]}


def test_join_on_invalid_expr():
    """join_on() rejects a bare string where an Expr predicate is required."""
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([4, 5])],
        names=["a", "b"],
    )
    left, right = [ctx.create_dataframe([[batch]], name) for name in ("l", "r")]

    hint = r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
    with pytest.raises(TypeError, match=hint):
        left.join_on(right, "a")


def test_aggregate_invalid_aggs(df):
    """Passing a string where aggregate Exprs are expected raises TypeError."""
    hint = r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
    with pytest.raises(TypeError, match=hint):
        df.aggregate([], "a")


def test_distinct():
    """distinct() collapses duplicated rows down to one copy of each."""
    ctx = SessionContext()

    def make_frame(a_values, b_values):
        # Build a two-column (a, b) DataFrame from a single batch.
        batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[batch]])

    deduplicated = make_frame([1, 2, 3, 1, 2, 3], [4, 5, 6, 4, 5, 6]).distinct()
    reference = make_frame([1, 2, 3], [4, 5, 6])

    actual = deduplicated.sort(column("a")).collect()
    expected = reference.sort(column("a")).collect()
    assert actual == expected


# Cases for test_window_functions: (output column name, window expression,
# expected values). Expected values are listed for partitioned_df's rows in
# ascending order of column "a" (the test sorts by "a" before comparing):
#   a = [0, 1, 2, 3, 4, 5, 6]
#   b = [7, None, 7, 8, 9, None, 9]
#   c = ["A", "A", "A", "A", "B", "B", "B"]
# As the rank/lag expectations show, nulls in "b" sort first when ordering
# ascending. Floating-point results are rounded to 3 places with f.round so
# they compare exactly.
data_test_window_functions = [
    (
        "row",
        f.row_number(order_by=[column("b"), column("a").sort(ascending=False)]),
        [4, 2, 3, 5, 7, 1, 6],
    ),
    (
        "row_w_params",
        f.row_number(
            order_by=[column("b"), column("a")],
            partition_by=[column("c")],
        ),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    (
        "row_w_params_no_lists",
        f.row_number(
            order_by=column("b"),
            partition_by=column("c"),
        ),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]),
    (
        "rank_w_params",
        f.rank(order_by=[column("b"), column("a")], partition_by=[column("c")]),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    (
        "rank_w_params_no_lists",
        f.rank(order_by=column("a"), partition_by=column("c")),
        [1, 2, 3, 4, 1, 2, 3],
    ),
    (
        "dense_rank",
        f.dense_rank(order_by=[column("b")]),
        [2, 1, 2, 3, 4, 1, 4],
    ),
    (
        "dense_rank_w_params",
        f.dense_rank(order_by=[column("b"), column("a")], partition_by=[column("c")]),
        [2, 1, 3, 4, 2, 1, 3],
    ),
    (
        "dense_rank_w_params_no_lists",
        f.dense_rank(order_by=column("a"), partition_by=column("c")),
        [1, 2, 3, 4, 1, 2, 3],
    ),
    (
        "percent_rank",
        f.round(f.percent_rank(order_by=[column("b")]), literal(3)),
        [0.333, 0.0, 0.333, 0.667, 0.833, 0.0, 0.833],
    ),
    (
        "percent_rank_w_params",
        f.round(
            f.percent_rank(
                order_by=[column("b"), column("a")], partition_by=[column("c")]
            ),
            literal(3),
        ),
        [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0],
    ),
    (
        "percent_rank_w_params_no_lists",
        f.round(
            f.percent_rank(order_by=column("a"), partition_by=column("c")),
            literal(3),
        ),
        [0.0, 0.333, 0.667, 1.0, 0.0, 0.5, 1.0],
    ),
    (
        "cume_dist",
        f.round(f.cume_dist(order_by=[column("b")]), literal(3)),
        [0.571, 0.286, 0.571, 0.714, 1.0, 0.286, 1.0],
    ),
    (
        "cume_dist_w_params",
        f.round(
            f.cume_dist(
                order_by=[column("b"), column("a")], partition_by=[column("c")]
            ),
            literal(3),
        ),
        [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0],
    ),
    (
        "cume_dist_w_params_no_lists",
        f.round(
            f.cume_dist(order_by=column("a"), partition_by=column("c")),
            literal(3),
        ),
        [0.25, 0.5, 0.75, 1.0, 0.333, 0.667, 1.0],
    ),
    (
        "ntile",
        f.ntile(2, order_by=[column("b")]),
        [1, 1, 1, 2, 2, 1, 2],
    ),
    (
        "ntile_w_params",
        f.ntile(2, order_by=[column("b"), column("a")], partition_by=[column("c")]),
        [1, 1, 2, 2, 1, 1, 2],
    ),
    (
        "ntile_w_params_no_lists",
        f.ntile(2, order_by=column("b"), partition_by=column("c")),
        [1, 1, 2, 2, 1, 1, 2],
    ),
    ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9, 7, None]),
    (
        "lead_w_params",
        f.lead(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=[column("b"), column("a")],
            partition_by=[column("c")],
        ),
        [8, 7, -1, -1, -1, 9, -1],
    ),
    (
        "lead_w_params_no_lists",
        f.lead(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=column("b"),
            partition_by=column("c"),
        ),
        [8, 7, -1, -1, -1, 9, -1],
    ),
    ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8, None, 9]),
    (
        "lag_w_params",
        f.lag(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=[column("b"), column("a")],
            partition_by=[column("c")],
        ),
        [-1, -1, None, 7, -1, -1, None],
    ),
    (
        "lag_w_params_no_lists",
        f.lag(
            column("b"),
            shift_offset=2,
            default_value=-1,
            order_by=column("b"),
            partition_by=column("c"),
        ),
        [-1, -1, None, 7, -1, -1, None],
    ),
    (
        "first_value",
        f.first_value(column("a")).over(
            Window(partition_by=[column("c")], order_by=[column("b")])
        ),
        [1, 1, 1, 1, 5, 5, 5],
    ),
    (
        "first_value_without_list_args",
        f.first_value(column("a")).over(
            Window(partition_by=column("c"), order_by=column("b"))
        ),
        [1, 1, 1, 1, 5, 5, 5],
    ),
    (
        "first_value_order_by_string",
        f.first_value(column("a")).over(
            Window(partition_by=[column("c")], order_by="b")
        ),
        [1, 1, 1, 1, 5, 5, 5],
    ),
    # An explicit ROWS frame with both bounds None spans the whole partition,
    # so last_value returns the partition's final row rather than the current
    # row (see the expected [3, ..., 6, ...] values).
    (
        "last_value",
        f.last_value(column("a")).over(
            Window(
                partition_by=[column("c")],
                order_by=[column("b")],
                window_frame=WindowFrame("rows", None, None),
            )
        ),
        [3, 3, 3, 3, 6, 6, 6],
    ),
    (
        "3rd_value",
        f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
        [None, None, 7, 7, 7, 7, 7],
    ),
    (
        "avg",
        f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)),
        [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
    ),
]


@pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions)
def test_window_functions(partitioned_df, name, expr, result):
    """Each window expression yields the expected value for every row."""
    projected = partitioned_df.select(
        column("a"), column("b"), column("c"), f.alias(expr, name)
    )
    projected.sort(column("a")).show()
    table = pa.Table.from_batches(projected.collect())

    expected = {
        "a": list(range(7)),
        "b": [7, None, 7, 8, 9, None, 9],
        "c": ["A"] * 4 + ["B"] * 3,
        name: result,
    }
    # Window evaluation may reorder rows, so compare in "a" order.
    assert table.sort_by("a").to_pydict() == expected


@pytest.mark.parametrize("partition", ["c", df_col("c")])
def test_rank_partition_by_accepts_string(partitioned_df, partition):
    """Passing a string to partition_by gives the same ranks as col()."""
    ranked = partitioned_df.select(
        f.rank(order_by=column("a"), partition_by=partition).alias("r")
    )
    ranks = pa.Table.from_batches(ranked.sort(column("a")).collect()).column("r")
    assert ranks.to_pylist() == [1, 2, 3, 4, 1, 2, 3]


@pytest.mark.parametrize("partition", ["c", df_col("c")])
def test_window_partition_by_accepts_string(partitioned_df, partition):
    """Window(partition_by=...) accepts a string identifier as well as col()."""
    window = Window(partition_by=partition, order_by=column("b"))
    first_a = f.first_value(column("a")).over(window).alias("fv")
    projected = partitioned_df.select(first_a)
    values = pa.Table.from_batches(projected.sort(column("a")).collect()).column("fv")
    assert values.to_pylist() == [1, 1, 1, 1, 5, 5, 5]


@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        # Every ROWS/RANGE combination of unbounded (None), 0 and 1 bounds.
        *itertools.product(("rows", "range"), (None, 0, 1), (None, 0, 1)),
        ("groups", 0, 0),
    ],
)
def test_valid_window_frame(units, start_bound, end_bound):
    """Supported units/bound combinations construct a WindowFrame without error."""
    WindowFrame(units, start_bound, end_bound)


@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        ("invalid-units", start, end)
        for start, end in ((0, None), (None, 0), (None, None))
    ]
    + [
        ("groups", start, end)
        for start, end in ((None, 0), (0, None), (None, None))
    ],
)
def test_invalid_window_frame(units, start_bound, end_bound):
    """Unknown units, and GROUPS frames with a None bound, are rejected."""
    # The error must mention the offending units, compared case-insensitively.
    with pytest.raises(NotImplementedError, match=f"(?i){units}"):
        WindowFrame(units, start_bound, end_bound)


def test_window_frame_defaults_match_postgres(partitioned_df):
    """Default window frames follow PostgreSQL semantics.

    Without an ORDER BY the frame spans UNBOUNDED PRECEDING to UNBOUNDED
    FOLLOWING, so every row sees the overall mean. With an ORDER BY it runs
    from UNBOUNDED PRECEDING to the current row, giving a running mean.
    """
    a = column("a")
    unordered = f.avg(a).over(Window()).alias("over_no_order")
    ordered = f.avg(a).over(Window(order_by=[a])).alias("over_with_order")

    result = partitioned_df.select(a, unordered, ordered).sort(a).to_pydict()

    assert result == {
        "a": [0, 1, 2, 3, 4, 5, 6],
        "over_no_order": [3.0] * 7,
        "over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
    }


def _build_last_value_df(df):
    """Select last_value(a) per ``c`` over a full frame, ordered by ``b``.

    The ``expr`` column orders by an expression and the ``str`` column orders by
    the plain column name, so callers can compare the two.
    """

    def _last_a(order_by):
        # ROWS frame with no bounds on either side: the whole partition.
        window = Window(
            partition_by=[column("c")],
            order_by=order_by,
            window_frame=WindowFrame("rows", None, None),
        )
        return f.last_value(column("a")).over(window)

    return df.select(
        _last_a([column("b")]).alias("expr"),
        _last_a("b").alias("str"),
    )


def _build_nth_value_df(df):
    """Select nth_value(b, 3) ordered by ``a``, as an expression and as a string."""

    def _third_b(order_by):
        return f.nth_value(column("b"), 3).over(Window(order_by=order_by))

    return df.select(
        _third_b([column("a")]).alias("expr"),
        _third_b("a").alias("str"),
    )


def _build_rank_df(df):
    """Select rank() ordered by ``b``, as an expression and as a string."""
    rank_by_expr = f.rank(order_by=[column("b")])
    rank_by_name = f.rank(order_by="b")
    return df.select(rank_by_expr.alias("expr"), rank_by_name.alias("str"))


def _build_array_agg_df(df):
    """Aggregate ``a`` into ordered arrays per ``c``, sorted by ``c``.

    Aggregates with an expression order_by and with a string order_by.
    """
    aggregates = [
        f.array_agg(column("a"), order_by=[column("a")]).alias("expr"),
        f.array_agg(column("a"), order_by="a").alias("str"),
    ]
    grouped = df.aggregate([column("c")], aggregates)
    return grouped.sort(column("c"))


@pytest.mark.parametrize(
    ("builder", "expected"),
    [
        pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"),
        pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"),
        pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"),
        pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"),
    ],
)
def test_order_by_string_equivalence(partitioned_df, builder, expected):
    """order_by given as a column name must behave like order_by given as an expr."""
    result = pa.Table.from_batches(builder(partitioned_df).collect())

    via_expr = result.column("expr").to_pylist()
    via_str = result.column("str").to_pylist()

    assert via_expr == expected
    assert via_expr == via_str


def test_html_formatter_cell_dimension(df, clean_formatter_state):
    """Cell max width/height appear as CSS, and cell expansion can be disabled."""
    configure_formatter(max_width=500, max_height=200, enable_cell_expansion=False)

    rendered = df._repr_html_()

    for css in ("max-height: 200px", "max-width: 500px"):
        assert css in rendered
    # With expansion turned off, no expandable containers should be rendered.
    assert "expandable-container" not in rendered


def test_html_formatter_custom_style_provider(df, clean_formatter_state):
    """Styles supplied by a custom style provider show up in the rendered HTML."""
    configure_formatter(style_provider=CustomStyleProvider())

    rendered = df._repr_html_()

    expected_styles = (
        "background-color: #4285f4",
        "color: white",
        "background-color: #f5f5f5",
    )
    for style in expected_styles:
        assert style in rendered


def test_html_formatter_type_formatters(df, clean_formatter_state):
    """A formatter registered for ``int`` is applied to integer cell values.

    Values greater than 2 are colored red and the rest blue. Both branches are
    checked so a formatter that ignores its condition would be caught.
    """
    formatter = get_formatter()

    # Registering against ``int`` works because _get_cell_value converts Arrow
    # scalar values to Python native types before formatting.
    def format_int(value):
        return f'<span style="color: {"red" if value > 2 else "blue"}">{value}</span>'

    formatter.register_formatter(int, format_int)

    html_output = df._repr_html_()

    # Column "a" holds 1, 2, 3: values <= 2 render blue, values > 2 render red.
    assert '<span style="color: blue">1</span>' in html_output
    assert '<span style="color: blue">2</span>' in html_output
    assert '<span style="color: red">3</span>' in html_output


def test_html_formatter_custom_cell_builder(df, clean_formatter_state):
    """A custom cell builder fully controls the markup of each data cell.

    Cells are bucketed by value: < 3 is "low", > 5 is "high", anything else is
    "mid". Each bucket gets a distinct style and a ``-low``/``-mid``/``-high``
    suffix, so the test can verify that every cell went through the builder.
    """

    def custom_cell_builder(value, row, col, table_id):
        try:
            num_value = int(value)
            if num_value > 5:  # Values > 5 get green background with indicator
                return (
                    '<td style="background-color: #d9f0d3" '
                    f'data-test="high">{value}-high</td>'
                )
            if num_value < 3:  # Values < 3 get blue background with indicator
                return (
                    '<td style="background-color: #d3e9f0" '
                    f'data-test="low">{value}-low</td>'
                )
        except (ValueError, TypeError):
            # Non-numeric values fall through to the default "mid" styling.
            pass

        # Default styling for other cells (3, 4, 5)
        return f'<td style="border: 1px solid #ddd" data-test="mid">{value}-mid</td>'

    formatter = get_formatter()
    formatter.set_custom_cell_builder(custom_cell_builder)

    html_output = df._repr_html_()

    # Extract the numeric value of every cell in each styling bucket.
    low_cells = re.findall(
        r'<td style="background-color: #d3e9f0"[^>]*>(\d+)-low</td>', html_output
    )
    mid_cells = re.findall(
        r'<td style="border: 1px solid #ddd"[^>]*>(\d+)-mid</td>', html_output
    )
    high_cells = re.findall(
        r'<td style="background-color: #d9f0d3"[^>]*>(\d+)-high</td>', html_output
    )

    # Sort the extracted values for consistent comparison
    low_cells = sorted(map(int, low_cells))
    mid_cells = sorted(map(int, mid_cells))
    high_cells = sorted(map(int, high_cells))

    # Verify specific values have the correct styling applied
    assert low_cells == [1, 2]  # Values < 3
    assert mid_cells == [3, 4, 5, 5]  # Values 3-5
    assert high_cells == [6, 8, 8]  # Values > 5

    # Verify the exact content with styling appears in the output
    assert (
        '<td style="background-color: #d3e9f0" data-test="low">1-low</td>'
        in html_output
    )
    assert (
        '<td style="background-color: #d3e9f0" data-test="low">2-low</td>'
        in html_output
    )
    assert (
        '<td style="border: 1px solid #ddd" data-test="mid">3-mid</td>' in html_output
    )
    assert (
        '<td style="border: 1px solid #ddd" data-test="mid">4-mid</td>' in html_output
    )
    assert (
        '<td style="background-color: #d9f0d3" data-test="high">6-high</td>'
        in html_output
    )
    assert (
        '<td style="background-color: #d9f0d3" data-test="high">8-high</td>'
        in html_output
    )

    # Count occurrences to ensure all cells are properly styled
    assert html_output.count("-low</td>") == 2  # Two low values (1, 2)
    assert html_output.count("-mid</td>") == 4  # Four mid values (3, 4, 5, 5)
    assert html_output.count("-high</td>") == 3  # Three high values (6, 8, 8)


def test_html_formatter_custom_header_builder(df, clean_formatter_state):
    """A custom header builder controls the markup of each column header."""
    column_tooltips = {
        "a": "Primary key column",
        "b": "Secondary values",
        "c": "Additional data",
    }

    def build_header(field):
        # Unknown columns get an empty title attribute.
        title = column_tooltips.get(field.name, "")
        return (
            '<th style="background-color: #333; color: white" '
            f'title="{title}">{field.name}</th>'
        )

    get_formatter().set_custom_header_builder(build_header)

    rendered = df._repr_html_()

    assert 'title="Primary key column"' in rendered
    assert 'title="Secondary values"' in rendered
    assert "background-color: #333; color: white" in rendered


def test_html_formatter_complex_customization(df, clean_formatter_state):
    """A style provider, custom CSS, max_cell_length and a type formatter combine."""

    class DarkModeStyleProvider:
        """Style provider returning a dark color scheme for headers and cells."""

        def get_header_style(self) -> str:
            return (
                "background-color: #111; color: #fff; padding: 10px; "
                "border: 1px solid #333;"
            )

        def get_cell_style(self) -> str:
            return (
                "background-color: #222; color: #eee; "
                "padding: 8px; border: 1px solid #444;"
            )

    configure_formatter(
        max_cell_length=10,
        style_provider=DarkModeStyleProvider(),
        custom_css="""
            .datafusion-table {
                font-family: monospace;
                border-collapse: collapse;
            }
            .datafusion-table tr:hover td {
                background-color: #444 !important;
            }
        """,
    )

    # Cell values arrive as native Python ints, so register against ``int``.
    def parity_color(n):
        shade = "#5af" if n % 2 == 0 else "#f5a"
        return f'<span style="color: {shade}">{n}</span>'

    get_formatter().register_formatter(int, parity_color)

    rendered = df._repr_html_()

    for fragment in (
        "background-color: #222",  # cell style from the provider
        "background-color: #111",  # header style from the provider
        ".datafusion-table",  # custom CSS
        "color: #5af",  # even numbers
    ):
        assert fragment in rendered


def test_html_formatter_memory(df, clean_formatter_state):
    """max_memory_bytes limits the rendered rows, but not below min_rows_display."""
    configure_formatter(max_memory_bytes=10, min_rows_display=1)
    html_output = df._repr_html_()

    # With a tiny memory limit of 10 bytes, the formatter falls back to the
    # minimum number of rows (1) and reports that the data was truncated.
    tr_count = count_table_rows(html_output)
    assert tr_count == 2  # 1 header row + 1 data row
    assert "data truncated" in html_output.lower()

    configure_formatter(max_memory_bytes=10 * MB, min_rows_display=1)
    html_output = df._repr_html_()

    # With a 10 MB limit all three rows fit. min_rows_display=1 is only a floor
    # and does not cap the output.
    tr_count = count_table_rows(html_output)
    assert tr_count == 4  # 1 header row + 3 data rows
    # No truncation message should appear
    assert "data truncated" not in html_output.lower()


def test_html_formatter_repr_rows(df, clean_formatter_state):
    """repr_rows controls how many data rows are rendered."""
    # (repr_rows, expected <tr> count including the single header row)
    for repr_rows, expected_tr in ((2, 3), (3, 4)):
        configure_formatter(min_rows_display=2, repr_rows=repr_rows)
        rendered = df._repr_html_()
        assert count_table_rows(rendered) == expected_tr


def test_html_formatter_validation():
    """DataFrameHtmlFormatter rejects zero or negative size and row parameters."""
    invalid_cases = [
        ("max_cell_length", 0),
        ("max_width", 0),
        ("max_height", 0),
        ("max_memory_bytes", 0),
        ("max_memory_bytes", -100),
        ("min_rows_display", 0),
        ("min_rows_display", -5),
        ("repr_rows", 0),
        ("repr_rows", -10),
    ]
    for param, bad_value in invalid_cases:
        with pytest.raises(ValueError, match=f"{param} must be a positive integer"):
            DataFrameHtmlFormatter(**{param: bad_value})


def test_configure_formatter(df, clean_formatter_state):
    """configure_formatter applies each supported keyword to the global formatter.

    Every value below differs from the formatter's default. The test first
    checks that a freshly reset formatter does not already hold the value, then
    checks that the configured formatter does.
    """
    non_default_settings = {
        "max_cell_length": 10,
        "max_width": 500,
        "max_height": 30,
        "max_memory_bytes": 3 * MB,
        "min_rows_display": 2,
        "repr_rows": 2,
        "enable_cell_expansion": False,
        "show_truncation_message": False,
        "use_shared_styles": False,
    }

    reset_formatter()
    formatter_default = get_formatter()
    for name, value in non_default_settings.items():
        # The attribute name is the assertion message, to pinpoint failures.
        assert getattr(formatter_default, name) != value, name

    configure_formatter(**non_default_settings)
    formatter_custom = get_formatter()
    for name, value in non_default_settings.items():
        assert getattr(formatter_custom, name) == value, name


def test_configure_formatter_invalid_params(clean_formatter_state):
    """configure_formatter raises ValueError for unknown keyword arguments."""
    bad_kwarg_sets = (
        {"invalid_param": 123},
        # one valid parameter mixed with one invalid parameter
        {"max_width": 500, "not_a_real_param": "test"},
        # several invalid parameters at once
        {"fake_param1": "test", "fake_param2": 456},
    )
    for kwargs in bad_kwarg_sets:
        with pytest.raises(ValueError, match="Invalid formatter parameters"):
            configure_formatter(**kwargs)


def test_get_dataframe(tmp_path):
    """ctx.table returns a DataFrame for a registered CSV table."""
    ctx = SessionContext()

    csv_path = tmp_path / "test.csv"
    source = pa.Table.from_arrays(
        [[1, 2, 3, 4], ["a", "b", "c", "d"], [1.1, 2.2, 3.3, 4.4]],
        names=["int", "str", "float"],
    )
    write_csv(source, csv_path)
    ctx.register_csv("csv", csv_path)

    assert isinstance(ctx.table("csv"), DataFrame)


def test_struct_select(struct_df):
    """Struct fields can be read by key and combined arithmetically with columns."""
    projected = struct_df.select(
        column("a")["c"] + column("b"),
        column("a")["c"] - column("b"),
    )

    # All rows fit in the first (and only) batch.
    first_batch = projected.collect()[0]

    assert first_batch.column(0) == pa.array([5, 7, 9])
    assert first_batch.column(1) == pa.array([-3, -3, -3])


def test_explain(df):
    """explain() runs on a projected DataFrame without raising."""
    projected = df.select(
        column("a") + column("b"),
        column("a") - column("b"),
    )
    projected.explain()


def test_logical_plan(aggregate_df):
    """logical_plan() renders the unoptimized plan: root node and full tree."""
    plan = aggregate_df.logical_plan()

    assert plan.display() == "Projection: test.c1, sum(test.c2)"

    expected_tree = "\n".join(
        [
            "Projection: test.c1, sum(test.c2)",
            "  Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]",
            "    TableScan: test",
        ]
    )
    assert plan.display_indent() == expected_tree


def test_optimized_logical_plan(aggregate_df):
    """optimized_logical_plan() drops the projection and prunes scanned columns."""
    plan = aggregate_df.optimized_logical_plan()
    root = "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]"

    assert plan.display() == root
    assert plan.display_indent() == "\n".join(
        [root, "  TableScan: test projection=[c1, c2]"]
    )


def test_execution_plan(aggregate_df):
    """execution_plan() can be displayed and executed partition by partition."""
    plan = aggregate_df.execution_plan()

    expected = (
        "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n"
    )

    assert expected == plan.display()

    # The partition count is exposed as a plain int (its value is used below).
    assert isinstance(plan.partition_count, int)

    indent = plan.display_indent()

    # indent plan will be different for everyone due to absolute path
    # to filename, so we just check for some expected content
    assert "AggregateExec:" in indent
    assert "CoalesceBatchesExec:" in indent
    assert "RepartitionExec:" in indent
    assert "DataSourceExec:" in indent
    assert "file_type=csv" in indent

    # Execute every partition separately; together they must return all rows.
    ctx = SessionContext()
    rows_returned = 0
    for idx in range(plan.partition_count):
        stream = ctx.execute(plan, idx)
        try:
            batch = stream.next()
            assert batch is not None
            rows_returned += len(batch.to_pyarrow()[0])
        except StopIteration:
            # This is one of the partitions with no values
            pass
        # Each partition yields at most one batch here.
        with pytest.raises(StopIteration):
            stream.next()

    assert rows_returned == 5


@pytest.mark.asyncio
async def test_async_iteration_of_df(aggregate_df):
    """A DataFrame supports ``async for`` over its record batches."""
    total_rows = 0
    async for record_batch in aggregate_df:
        assert record_batch is not None
        total_rows += len(record_batch.to_pyarrow()[0])

    assert total_rows == 5


def test_repartition(df):
    """repartition(n) returns a DataFrame that still holds every row."""
    repartitioned = df.repartition(2)
    assert repartitioned.count() == df.count()


def test_repartition_by_hash(df):
    """Hash repartitioning on an expression keeps every row."""
    repartitioned = df.repartition_by_hash(column("a"), num=2)
    assert repartitioned.count() == df.count()


def test_repartition_by_hash_sql_expression(df):
    """Hash repartitioning accepts a SQL expression string and keeps every row."""
    repartitioned = df.repartition_by_hash("a", num=2)
    assert repartitioned.count() == df.count()


def test_repartition_by_hash_mix(df):
    """Hash repartitioning accepts mixed expr/string keys and keeps every row."""
    repartitioned = df.repartition_by_hash(column("a"), "b", num=2)
    assert repartitioned.count() == df.count()


def test_intersect():
    """intersect keeps only the rows present in both DataFrames."""
    ctx = SessionContext()

    def make_df(a_values, b_values):
        batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[batch]])

    left = make_df([1, 2, 3], [4, 5, 6])
    right = make_df([3, 4, 5], [6, 7, 8])
    expected = make_df([3], [6]).sort(column("a"))

    actual = left.intersect(right).sort(column("a"))

    assert expected.collect() == actual.collect()


def test_except_all():
    """except_all removes the rows of the right DataFrame from the left one."""
    ctx = SessionContext()

    def make_df(a_values, b_values):
        batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[batch]])

    left = make_df([1, 2, 3], [4, 5, 6])
    right = make_df([3, 4, 5], [6, 7, 8])
    expected = make_df([1, 2], [4, 5]).sort(column("a"))

    actual = left.except_all(right).sort(column("a"))

    assert expected.collect() == actual.collect()


def test_collect_partitioned():
    """collect_partitioned returns batches grouped by their source partition."""
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"]
    )
    partitions = [[batch]]

    assert partitions == ctx.create_dataframe(partitions).collect_partitioned()


def test_collect_column(ctx: SessionContext):
    """collect_column gathers one column across all batches and partitions."""
    first, second, third = (
        pa.RecordBatch.from_pydict({"a": values})
        for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])
    )
    ctx.register_record_batches("t", [[first, second], [third]])

    collected = ctx.table("t").sort(column("a")).collect_column("a")

    assert collected == pa.array(list(range(1, 10)))


def test_union(ctx):
    """union without distinct keeps duplicates (UNION ALL semantics)."""

    def make_df(a_values, b_values):
        batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[batch]])

    left = make_df([1, 2, 3], [4, 5, 6])
    right = make_df([3, 4, 5], [6, 7, 8])
    # The row (3, 6) appears in both inputs and therefore twice here.
    expected = make_df([1, 2, 3, 3, 4, 5], [4, 5, 6, 6, 7, 8]).sort(column("a"))

    actual = left.union(right).sort(column("a"))

    assert expected.collect() == actual.collect()


def test_union_distinct(ctx):
    """union(distinct=True) drops rows duplicated across the two inputs."""
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df_a = ctx.create_dataframe([[batch]])

    batch = pa.RecordBatch.from_arrays(
        [pa.array([3, 4, 5]), pa.array([6, 7, 8])],
        names=["a", "b"],
    )
    df_b = ctx.create_dataframe([[batch]])

    # (3, 6) is present in both inputs but must appear only once.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
        names=["a", "b"],
    )
    df_c = ctx.create_dataframe([[batch]]).sort(column("a"))

    df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a"))

    assert df_c.collect() == df_a_u_b.collect()


def test_cache(df):
    """cache() materializes the DataFrame without changing its contents."""
    cached = df.cache()
    assert cached.collect() == df.collect()


def test_count(df):
    """count() returns the number of rows in the DataFrame."""
    row_count = df.count()
    assert row_count == 3


def test_to_pandas(df):
    """to_pandas() yields a pandas DataFrame with the same shape and columns."""
    # pandas is optional: skip rather than fail when it is missing.
    pd = pytest.importorskip("pandas")

    converted = df.to_pandas()

    assert isinstance(converted, pd.DataFrame)
    assert converted.shape == (3, 3)
    assert set(converted.columns) == {"a", "b", "c"}


def test_empty_to_pandas(df):
    """An empty DataFrame converts to pandas and keeps its column names."""
    # pandas is optional: skip rather than fail when it is missing.
    pd = pytest.importorskip("pandas")

    empty = df.limit(0).to_pandas()

    assert isinstance(empty, pd.DataFrame)
    assert empty.shape == (0, 3)
    assert set(empty.columns) == {"a", "b", "c"}


def test_to_polars(df):
    """to_polars() yields a polars DataFrame with the same shape and columns."""
    # polars is optional: skip rather than fail when it is missing.
    pl = pytest.importorskip("polars")

    converted = df.to_polars()

    assert isinstance(converted, pl.DataFrame)
    assert converted.shape == (3, 3)
    assert set(converted.columns) == {"a", "b", "c"}


def test_empty_to_polars(df):
    """An empty DataFrame converts to polars and keeps its column names."""
    # polars is optional: skip rather than fail when it is missing.
    pl = pytest.importorskip("polars")

    empty = df.limit(0).to_polars()

    assert isinstance(empty, pl.DataFrame)
    assert empty.shape == (0, 3)
    assert set(empty.columns) == {"a", "b", "c"}


def test_to_arrow_table(df):
    """to_arrow_table() yields a pyarrow Table with the same shape and columns."""
    table = df.to_arrow_table()

    assert isinstance(table, pa.Table)
    assert table.shape == (3, 3)
    assert set(table.column_names) == {"a", "b", "c"}


def test_execute_stream(df):
    """execute_stream yields non-None batches and can only be iterated once."""
    stream = df.execute_stream()
    for batch in stream:
        assert batch is not None

    # After one full pass the stream is exhausted.
    assert not list(stream)


@pytest.mark.asyncio
async def test_execute_stream_async(df):
    """execute_stream supports async iteration and can only be iterated once."""
    stream = df.execute_stream()

    first_pass = [item async for item in stream]
    assert all(item is not None for item in first_pass)

    # A second async pass over the consumed stream yields nothing.
    second_pass = [item async for item in stream]
    assert not second_pass


@pytest.mark.parametrize("schema", [True, False])
def test_execute_stream_to_arrow_table(df, schema):
    """Streamed batches build a complete pyarrow Table, with or without a schema."""
    batches = (item.to_pyarrow() for item in df.execute_stream())
    schema_kwargs = {"schema": df.schema()} if schema else {}

    table = pa.Table.from_batches(batches, **schema_kwargs)

    assert isinstance(table, pa.Table)
    assert table.shape == (3, 3)
    assert set(table.column_names) == {"a", "b", "c"}


@pytest.mark.asyncio
@pytest.mark.parametrize("schema", [True, False])
async def test_execute_stream_to_arrow_table_async(df, schema):
    """Asynchronously streamed batches build a complete pyarrow Table."""
    stream = df.execute_stream()
    batches = [item.to_pyarrow() async for item in stream]

    if schema:
        table = pa.Table.from_batches(batches, schema=df.schema())
    else:
        table = pa.Table.from_batches(batches)

    assert isinstance(table, pa.Table)
    assert table.shape == (3, 3)
    assert set(table.column_names) == {"a", "b", "c"}


def test_execute_stream_partitioned(df):
    """Each partition stream yields non-None batches and can only be iterated once."""
    streams = df.execute_stream_partitioned()

    for partition_stream in streams:
        for batch in partition_stream:
            assert batch is not None

    # After one full pass every partition stream is exhausted.
    for partition_stream in streams:
        assert not list(partition_stream)


@pytest.mark.asyncio
async def test_execute_stream_partitioned_async(df):
    """Each partition stream supports async iteration and can only be iterated once."""
    for partition_stream in df.execute_stream_partitioned():
        collected = [item async for item in partition_stream]
        assert all(item is not None for item in collected)

        # A second pass over the consumed stream yields nothing.
        leftover = [item async for item in partition_stream]
        assert not leftover


def test_empty_to_arrow_table(df):
    """An empty DataFrame converts to a zero-row pyarrow Table with all columns."""
    empty = df.limit(0).to_arrow_table()

    assert isinstance(empty, pa.Table)
    assert empty.shape == (0, 3)
    assert set(empty.column_names) == {"a", "b", "c"}


def test_iter_batches_dataframe(fail_collect):
    """Iterating a DataFrame yields every batch from every partition.

    ``fail_collect`` is requested for its side effect (presumably it makes
    ``collect()`` fail, so iteration must not rely on it; verify against the
    fixture).
    """
    ctx = SessionContext()
    first = pa.record_batch([pa.array([1])], names=["a"])
    second = pa.record_batch([pa.array([2])], names=["a"])
    frame = ctx.create_dataframe([[first], [second]])

    streamed = [rb.to_pyarrow() for rb in frame]

    # Partition order is not asserted, so look for each batch anywhere.
    assert len(streamed) == 2
    for wanted in (first, second):
        assert any(got.equals(wanted) for got in streamed)


def test_arrow_c_stream_to_table_and_reader(fail_collect):
    """A DataFrame can feed pa.Table.from_batches and a RecordBatchReader."""
    ctx = SessionContext()

    # Two partitions with one single-row batch each.
    first = pa.record_batch([pa.array([1])], names=["a"])
    second = pa.record_batch([pa.array([2])], names=["a"])
    df = ctx.create_dataframe([[first], [second]])

    table = pa.Table.from_batches(item.to_pyarrow() for item in df)
    table_batches = table.to_batches()

    assert len(table_batches) == 2
    for wanted in (first, second):
        assert any(got.equals(wanted) for got in table_batches)
    assert table.schema == df.schema()
    assert table.column("a").num_chunks == 2

    reader = pa.RecordBatchReader.from_stream(df)
    assert isinstance(reader, pa.RecordBatchReader)
    from_reader = pa.Table.from_batches(reader)
    assert from_reader.equals(pa.Table.from_batches([first, second]))


def test_arrow_c_stream_order():
    """Batches within one partition are streamed in their original order."""
    ctx = SessionContext()
    first = pa.record_batch([pa.array([1])], names=["a"])
    second = pa.record_batch([pa.array([2])], names=["a"])
    df = ctx.create_dataframe([[first, second]])

    table = pa.Table.from_batches(item.to_pyarrow() for item in df)

    assert table.equals(pa.Table.from_batches([first, second]))
    chunks = table.column("a")
    for idx, value in enumerate((1, 2)):
        assert chunks.chunk(idx)[0].as_py() == value


def test_arrow_c_stream_schema_selection(fail_collect):
    """A requested schema can select and reorder a subset of columns.

    ``pa.RecordBatchReader.from_stream`` passes ``schema`` on to
    ``df.__arrow_c_stream__`` as the requested schema. The resulting stream must
    contain only columns ``c`` and ``a``, in that order.
    """
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([1, 2]),
            pa.array([3, 4]),
            pa.array([5, 6]),
        ],
        names=["a", "b", "c"],
    )
    df = ctx.create_dataframe([[batch]])

    requested_schema = pa.schema([("c", pa.int64()), ("a", pa.int64())])

    reader = pa.RecordBatchReader.from_stream(df, schema=requested_schema)

    assert reader.schema == requested_schema

    batches = list(reader)

    assert len(batches) == 1
    expected_batch = pa.record_batch(
        [pa.array([5, 6]), pa.array([1, 2])], names=["c", "a"]
    )
    assert batches[0].equals(expected_batch)


def test_arrow_c_stream_schema_mismatch(fail_collect):
    """__arrow_c_stream__ raises when the requested schema is incompatible.

    The DataFrame holds integer columns ``a`` and ``b``. The requested schema
    declares ``a`` as a string, and the export is expected to fail with a
    "Fail to merge schema" error.
    """
    ctx = SessionContext()

    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([3, 4])], names=["a", "b"]
    )
    df = ctx.create_dataframe([[batch]])

    bad_schema = pa.schema([("a", pa.string())])

    # Export the requested schema into a cffi-allocated ArrowSchema struct.
    # ``c_schema`` must stay referenced until the call below, because the capsule
    # created next only stores the raw address.
    c_schema = pa_cffi.ffi.new("struct ArrowSchema*")
    address = int(pa_cffi.ffi.cast("uintptr_t", c_schema))
    bad_schema._export_to_c(address)

    # Hand-build a PyCapsule named "arrow_schema" around the raw pointer, which
    # is how a requested schema is passed to __arrow_c_stream__. No destructor
    # is registered (None).
    capsule_new = ctypes.pythonapi.PyCapsule_New
    capsule_new.restype = ctypes.py_object
    capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
    bad_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None)

    with pytest.raises(Exception, match="Fail to merge schema"):
        df.__arrow_c_stream__(bad_capsule)


def test_to_pylist(df):
    """to_pylist() returns one dict per row, keyed by column name."""
    rows = df.to_pylist()

    assert isinstance(rows, list)
    expected_rows = [
        {"a": 1, "b": 4, "c": 8},
        {"a": 2, "b": 5, "c": 5},
        {"a": 3, "b": 6, "c": 8},
    ]
    assert rows == expected_rows


def test_to_pydict(df):
    """to_pydict() returns a mapping from column name to its list of values."""
    columns = df.to_pydict()

    assert isinstance(columns, dict)
    assert columns == {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]}


def test_describe(df):
    """describe() reports summary statistics for every column."""
    stats = df.describe().to_pydict()

    statistic_names = ["count", "null_count", "mean", "std", "min", "max", "median"]
    assert stats == {
        "describe": statistic_names,
        "a": [3.0, 0.0, 2.0, 1.0, 1.0, 3.0, 2.0],
        "b": [3.0, 0.0, 5.0, 1.0, 4.0, 6.0, 5.0],
        "c": [3.0, 0.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0],
    }


@pytest.mark.parametrize("path_to_str", [True, False])
def test_write_csv(ctx, df, tmp_path, path_to_str):
    """Round-trip ``df`` through CSV, passing the target as ``str`` or ``Path``."""
    target = str(tmp_path) if path_to_str else tmp_path

    df.write_csv(target, with_header=True)
    ctx.register_csv("csv", target)

    assert ctx.table("csv").to_pydict() == df.to_pydict()


def generate_test_write_params() -> list[tuple]:
    """Build the parameter matrix for ``test_write_files_with_options``.

    Every output format is paired with every insert operation and every
    ``sort_by`` case. Each case carries the expected order of column ``a`` and
    a short label that becomes part of the generated test id.
    """
    # Overwrite and Replace are not implemented for many table writers
    operations = [InsertOp.APPEND, None]
    sort_cases = [
        (None, [1, 2, 3], "unsorted"),
        (column("c"), [2, 1, 3], "single_column_expr"),
        (column("a").sort(ascending=False), [3, 2, 1], "single_sort_expr"),
        ([column("c"), column("b")], [2, 1, 3], "list_col_expr"),
        (
            [column("c").sort(ascending=False), column("b").sort(ascending=False)],
            [3, 1, 2],
            "list_sort_expr",
        ),
    ]

    # Nested clauses iterate in the same order as a cartesian product
    # (format outermost, sort case innermost).
    return [
        pytest.param(fmt, op, sort_by, expected_a, id=f"{fmt}_{label}")
        for fmt in ("csv", "json", "parquet", "table")
        for op in operations
        for sort_by, expected_a, label in sort_cases
    ]


@pytest.mark.parametrize(
    ("output_format", "insert_op", "sort_by", "expected_a"),
    generate_test_write_params(),
)
def test_write_files_with_options(
    ctx, df, tmp_path, output_format, insert_op, sort_by, expected_a
) -> None:
    """Write ``df`` with ``DataFrameWriteOptions`` and check the resulting row order.

    The csv/json/parquet formats write into ``tmp_path`` and register that
    directory as ``test_table``. The ``"table"`` format writes into a freshly
    registered, empty in-memory table with the same schema. In every case the
    ``a`` column read back must equal ``expected_a``.
    """
    write_options = DataFrameWriteOptions(insert_operation=insert_op, sort_by=sort_by)

    if output_format == "csv":
        df.write_csv(tmp_path, with_header=True, write_options=write_options)
        ctx.register_csv("test_table", tmp_path)
    elif output_format == "json":
        df.write_json(tmp_path, write_options=write_options)
        ctx.register_json("test_table", tmp_path)
    elif output_format == "parquet":
        df.write_parquet(tmp_path, write_options=write_options)
        ctx.register_parquet("test_table", tmp_path)
    elif output_format == "table":
        # Start from an empty table so the written rows are the only rows in it.
        batch = pa.RecordBatch.from_arrays([[], [], []], schema=df.schema())
        ctx.register_record_batches("test_table", [[batch]])
        df.write_table("test_table", write_options=write_options)
    else:
        # Fail explicitly instead of surfacing a confusing "table not found".
        pytest.fail(f"unexpected output format: {output_format!r}")

    result = ctx.table("test_table").to_pydict()["a"]

    assert result == expected_a


@pytest.mark.parametrize("path_to_str", [True, False])
def test_write_json(ctx, df, tmp_path, path_to_str):
    """Round-trip ``df`` through JSON, passing the target as ``str`` or ``Path``."""
    target = tmp_path if not path_to_str else str(tmp_path)

    df.write_json(target)
    ctx.register_json("json", target)
    round_tripped = ctx.table("json").to_pydict()

    assert round_tripped == df.to_pydict()


@pytest.mark.parametrize("path_to_str", [True, False])
def test_write_parquet(df, tmp_path, path_to_str):
    """Round-trip ``df`` through Parquet, passing the target as ``str`` or ``Path``.

    The path is handed to ``write_parquet`` as-is. Wrapping it in ``str()`` here
    would make both parametrized cases identical and leave ``Path`` input
    untested.
    """
    path = str(tmp_path) if path_to_str else tmp_path

    df.write_parquet(path)
    result = pq.read_table(str(path)).to_pydict()
    expected = df.to_pydict()

    assert result == expected


@pytest.mark.parametrize(
    ("compression", "compression_level"),
    [("gzip", 6), ("brotli", 7), ("zstd", 15)],
)
def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
    """Write Parquet with an explicit codec and level, then verify it on disk.

    Every column chunk of every written file must report ``compression``, and
    the data must round-trip unchanged.
    """
    path = tmp_path

    df.write_parquet(
        str(path), compression=compression, compression_level=compression_level
    )

    # Check that the compression scheme actually written is the requested one.
    # ``rglob`` yields full paths, so files in nested directories open
    # correctly. Requiring at least one file stops the loop passing vacuously.
    parquet_files = list(path.rglob("*.parquet"))
    assert parquet_files, "no parquet files were written"
    for file in parquet_files:
        metadata = pq.ParquetFile(file).metadata.to_dict()
        for row_group in metadata["row_groups"]:
            for columns in row_group["columns"]:
                assert columns["compression"].lower() == compression

    result = pq.read_table(str(path)).to_pydict()
    expected = df.to_pydict()

    assert result == expected


@pytest.mark.parametrize(
    ("compression", "compression_level"),
    [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
)
def test_write_compressed_parquet_wrong_compression_level(
    df, tmp_path, compression, compression_level
):
    """Out-of-range levels (or an unknown codec) are rejected with ``ValueError``."""
    write_kwargs = {"compression": compression, "compression_level": compression_level}

    with pytest.raises(ValueError):
        df.write_parquet(str(tmp_path), **write_kwargs)


@pytest.mark.parametrize("compression", ["wrong"])
def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
    """An unrecognised codec name is rejected with ``ValueError``."""
    target = str(tmp_path)
    with pytest.raises(ValueError):
        df.write_parquet(target, compression=compression)


# lzo is not tested because it is not implemented yet
# https://github.com/apache/arrow-rs/issues/6970
@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
    """Writing with a codec but without an explicit level must succeed.

    No ``compression_level`` is passed, so each codec uses its default level.
    The test only checks that the write completes without raising.
    """
    df.write_parquet(str(tmp_path), compression=compression)


def test_write_parquet_with_options_default_compression(df, tmp_path):
    """Without a compression argument, ``write_parquet`` uses ZSTD for every chunk."""
    df.write_parquet(tmp_path)

    codecs = [
        chunk["compression"].lower()
        for parquet_path in tmp_path.rglob("*.parquet")
        for group in pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for chunk in group["columns"]
    ]
    assert all(codec == "zstd" for codec in codecs)


@pytest.mark.parametrize(
    "compression",
    ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"],
)
def test_write_parquet_with_options_compression(df, tmp_path, compression):
    """``ParquetWriterOptions(compression=...)`` selects the codec used on disk.

    The option string may carry a level suffix such as ``"gzip(6)"``. The codec
    recorded in the file metadata is that name with the suffix removed.
    """
    path = tmp_path
    df.write_parquet_with_options(
        str(path), ParquetWriterOptions(compression=compression)
    )

    # Strip an optional "(level)" suffix once, outside the loops.
    expected_codec = re.sub(r"\(\d+\)", "", compression)

    # Check that the compression scheme actually written is the requested one.
    # Fail if no file was produced instead of passing vacuously.
    parquet_files = list(path.rglob("*.parquet"))
    assert parquet_files, "no parquet files were written"
    for file in parquet_files:
        metadata = pq.ParquetFile(file).metadata.to_dict()
        for row_group in metadata["row_groups"]:
            for col in row_group["columns"]:
                assert col["compression"].lower() == expected_codec

    result = pq.read_table(str(path)).to_pydict()
    expected = df.to_pydict()

    assert result == expected


@pytest.mark.parametrize(
    "compression",
    ["gzip(12)", "brotli(15)", "zstd(23)"],
)
def test_write_parquet_with_options_wrong_compression_level(df, tmp_path, compression):
    """A level outside the codec's permitted range raises an error."""
    target = str(tmp_path)
    with pytest.raises(Exception, match=r"valid compression range .*? exceeded."):
        options = ParquetWriterOptions(compression=compression)
        df.write_parquet_with_options(target, options)


@pytest.mark.parametrize("compression", ["wrong", "wrong(12)"])
def test_write_parquet_with_options_invalid_compression(df, tmp_path, compression):
    """Unknown codec names, with or without a level suffix, raise an error."""
    expected_message = "Unknown or unsupported parquet compression"
    target = str(tmp_path)
    with pytest.raises(Exception, match=expected_message):
        options = ParquetWriterOptions(compression=compression)
        df.write_parquet_with_options(target, options)


@pytest.mark.parametrize(
    ("writer_version", "format_version"),
    [("1.0", "1.0"), ("2.0", "2.6"), (None, "1.0")],
)
def test_write_parquet_with_options_writer_version(
    df, tmp_path, writer_version, format_version
):
    """The writer version maps to the file's reported format version.

    ``writer_version="2.0"`` is reported by pyarrow as format version ``2.6``.
    ``None`` leaves the option unset so the default is exercised.
    """
    options = (
        ParquetWriterOptions()
        if writer_version is None
        else ParquetWriterOptions(writer_version=writer_version)
    )
    df.write_parquet_with_options(tmp_path, options)

    for parquet_path in tmp_path.rglob("*.parquet"):
        file_meta = pq.ParquetFile(parquet_path).metadata.to_dict()
        assert file_meta["format_version"] == format_version


@pytest.mark.parametrize("writer_version", ["1.2.3", "custom-version", "0"])
def test_write_parquet_with_options_wrong_writer_version(df, tmp_path, writer_version):
    """Unrecognised ``writer_version`` strings cause the Parquet write to fail."""
    expected_message = "Unknown or unsupported parquet writer version"
    with pytest.raises(Exception, match=expected_message):
        options = ParquetWriterOptions(writer_version=writer_version)
        df.write_parquet_with_options(tmp_path, options)


@pytest.mark.parametrize("dictionary_enabled", [True, False, None])
def test_write_parquet_with_options_dictionary_enabled(
    df, tmp_path, dictionary_enabled
):
    """``dictionary_enabled`` toggles dictionary pages in every column chunk."""
    options = ParquetWriterOptions(dictionary_enabled=dictionary_enabled)
    df.write_parquet_with_options(tmp_path, options)

    # Leaving the option unset (None) keeps the default, which is enabled.
    expect_dictionary = True if dictionary_enabled is None else dictionary_enabled

    for parquet_path in tmp_path.rglob("*.parquet"):
        row_groups = pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for group in row_groups:
            for chunk in group["columns"]:
                assert chunk["has_dictionary_page"] == expect_dictionary


@pytest.mark.parametrize(
    ("statistics_enabled", "has_statistics"),
    [("page", True), ("chunk", True), ("none", False), (None, True)],
)
def test_write_parquet_with_options_statistics_enabled(
    df, tmp_path, statistics_enabled, has_statistics
):
    """Column-chunk statistics are present unless ``statistics_enabled="none"``.

    pyarrow only exposes column-level statistics, so the ``"page"`` and
    ``"chunk"`` levels are checked the same way.
    """
    options = ParquetWriterOptions(statistics_enabled=statistics_enabled)
    df.write_parquet_with_options(tmp_path, options)

    for parquet_path in tmp_path.rglob("*.parquet"):
        row_groups = pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for group in row_groups:
            for chunk in group["columns"]:
                assert (chunk["statistics"] is not None) == has_statistics


@pytest.mark.parametrize("max_row_group_size", [1000, 5000, 10000, 100000])
def test_write_parquet_with_options_max_row_group_size(
    large_df, tmp_path, max_row_group_size
):
    """Each row group holds exactly ``max_row_group_size`` rows.

    This holds because every parametrized size evenly divides the row count of
    ``large_df``, so no trailing partial row group is produced.
    """
    target = f"{tmp_path}/t.parquet"
    options = ParquetWriterOptions(max_row_group_size=max_row_group_size)
    large_df.write_parquet_with_options(target, options)

    row_groups = pq.ParquetFile(target).metadata.to_dict()["row_groups"]
    assert all(group["num_rows"] == max_row_group_size for group in row_groups)


@pytest.mark.parametrize("created_by", ["datafusion", "datafusion-python", "custom"])
def test_write_parquet_with_options_created_by(df, tmp_path, created_by):
    """The ``created_by`` option is recorded verbatim in the file metadata."""
    options = ParquetWriterOptions(created_by=created_by)
    df.write_parquet_with_options(tmp_path, options)

    for parquet_path in tmp_path.rglob("*.parquet"):
        file_meta = pq.ParquetFile(parquet_path).metadata.to_dict()
        assert file_meta["created_by"] == created_by


@pytest.mark.parametrize("statistics_truncate_length", [5, 25, 50])
def test_write_parquet_with_options_statistics_truncate_length(
    df, tmp_path, statistics_truncate_length
):
    """Row-group min/max string statistics never exceed the configured length."""
    session = SessionContext()
    long_strings = {
        "a": [
            "a_the_quick_brown_fox_jumps_over_the_lazy_dog",
            "m_the_quick_brown_fox_jumps_over_the_lazy_dog",
            "z_the_quick_brown_fox_jumps_over_the_lazy_dog",
        ],
        "b": ["a_smaller", "m_smaller", "z_smaller"],
    }
    # A dedicated name avoids shadowing the ``df`` fixture argument.
    strings_df = session.from_arrow(pa.record_batch(long_strings))
    strings_df.write_parquet_with_options(
        tmp_path,
        ParquetWriterOptions(statistics_truncate_length=statistics_truncate_length),
    )

    limit = statistics_truncate_length
    for parquet_path in tmp_path.rglob("*.parquet"):
        row_groups = pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for group in row_groups:
            for chunk in group["columns"]:
                stats = chunk["statistics"]
                assert len(stats["min"]) <= limit
                assert len(stats["max"]) <= limit


def test_write_parquet_with_options_default_encoding(tmp_path):
    """Default options write int, string and float columns with dictionary encoding.

    Booleans are left out because dictionary encoding is not used for them.
    """
    session = SessionContext()
    columns = {
        "a": [1, 2, 3],
        "b": ["1", "2", "3"],
        "c": [1.01, 2.02, 3.03],
    }
    frame = session.from_arrow(pa.record_batch(columns))
    frame.write_parquet_with_options(tmp_path, ParquetWriterOptions())

    expected_encodings = ("PLAIN", "RLE", "RLE_DICTIONARY")
    for parquet_path in tmp_path.rglob("*.parquet"):
        row_groups = pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for group in row_groups:
            for chunk in group["columns"]:
                assert chunk["encodings"] == expected_encodings


@pytest.mark.parametrize(
    ("encoding", "data_types", "result"),
    [
        ("plain", ["int", "float", "str", "bool"], ("PLAIN", "RLE")),
        ("rle", ["bool"], ("RLE",)),
        ("delta_binary_packed", ["int"], ("RLE", "DELTA_BINARY_PACKED")),
        ("delta_length_byte_array", ["str"], ("RLE", "DELTA_LENGTH_BYTE_ARRAY")),
        ("delta_byte_array", ["str"], ("RLE", "DELTA_BYTE_ARRAY")),
        ("byte_stream_split", ["int", "float"], ("RLE", "BYTE_STREAM_SPLIT")),
    ],
)
def test_write_parquet_with_options_encoding(tmp_path, encoding, data_types, result):
    """Each encoding is written to the column types that support it.

    Dictionary encoding is switched off for these writes, and every column
    chunk must report exactly the expected ``result`` encodings.
    """
    session = SessionContext()

    # Sample values per supported type; the column name equals the type label.
    samples = {
        "int": [1, 2, 3],
        "float": [1.01, 2.02, 3.03],
        "str": ["a", "b", "c"],
        "bool": [True, False, True],
    }
    data = {name: samples[name] for name in data_types if name in samples}

    frame = session.from_arrow(pa.record_batch(data))
    frame.write_parquet_with_options(
        tmp_path, ParquetWriterOptions(encoding=encoding, dictionary_enabled=False)
    )

    for parquet_path in tmp_path.rglob("*.parquet"):
        row_groups = pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for group in row_groups:
            for chunk in group["columns"]:
                assert chunk["encodings"] == result


@pytest.mark.parametrize("encoding", ["bit_packed"])
def test_write_parquet_with_options_unsupported_encoding(df, tmp_path, encoding):
    """A known but unusable encoding (``bit_packed``) makes the write fail."""
    # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519
    with pytest.raises(BaseException, match="Encoding .*? is not supported"):
        options = ParquetWriterOptions(encoding=encoding)
        df.write_parquet_with_options(tmp_path, options)


@pytest.mark.parametrize("encoding", ["non_existent", "unknown", "plain123"])
def test_write_parquet_with_options_invalid_encoding(df, tmp_path, encoding):
    """Encoding names that do not exist at all raise an error."""
    expected_message = "Unknown or unsupported parquet encoding"
    with pytest.raises(Exception, match=expected_message):
        options = ParquetWriterOptions(encoding=encoding)
        df.write_parquet_with_options(tmp_path, options)


@pytest.mark.parametrize("encoding", ["plain_dictionary", "rle_dictionary"])
def test_write_parquet_with_options_dictionary_encoding_fallback(
    df, tmp_path, encoding
):
    """Dictionary encodings are not accepted as the (fallback) ``encoding`` option."""
    expected_message = "Dictionary encoding can not be used as fallback encoding"
    # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519
    with pytest.raises(BaseException, match=expected_message):
        options = ParquetWriterOptions(encoding=encoding)
        df.write_parquet_with_options(tmp_path, options)


def test_write_parquet_with_options_bloom_filter(df, tmp_path):
    """Enabling ``bloom_filter_on_write`` produces larger Parquet output.

    pyarrow exposes nothing about bloom filters, so their presence is inferred
    by comparing the total file size against a write with default options.
    """
    without_bloom = tmp_path / "1"
    with_bloom = tmp_path / "2"

    df.write_parquet_with_options(without_bloom, ParquetWriterOptions())
    df.write_parquet_with_options(
        with_bloom, ParquetWriterOptions(bloom_filter_on_write=True)
    )

    def _total_size(directory: Path) -> int:
        # Sum the on-disk size of every Parquet file under ``directory``.
        return sum(Path(f).stat().st_size for f in directory.rglob("*.parquet"))

    assert _total_size(without_bloom) < _total_size(with_bloom)


def test_write_parquet_with_options_column_options(df, tmp_path):
    """Per-column ``ParquetColumnOptions`` override the global writer settings.

    Columns ``a``-``d`` each override some of statistics, compression, encoding
    and dictionary settings. Column ``e`` has no override, so it must use the
    global ``brotli(8)`` compression and otherwise default settings.
    """
    data = {
        "a": [1, 2, 3],
        "b": ["a", "b", "c"],
        "c": [False, True, False],
        "d": [1.01, 2.02, 3.03],
        "e": [4, 5, 6],
    }

    per_column = {
        "a": ParquetColumnOptions(statistics_enabled="none"),
        "b": ParquetColumnOptions(encoding="plain", dictionary_enabled=False),
        "c": ParquetColumnOptions(
            compression="snappy", encoding="rle", dictionary_enabled=False
        ),
        "d": ParquetColumnOptions(
            compression="zstd(6)",
            encoding="byte_stream_split",
            dictionary_enabled=False,
            statistics_enabled="none",
        ),
        # column "e" will use the global configs
    }

    # Expected metadata per column: (has statistics, codec, encodings).
    expected = {
        "a": (False, "brotli", ("PLAIN", "RLE", "RLE_DICTIONARY")),
        "b": (True, "brotli", ("PLAIN", "RLE")),
        "c": (True, "snappy", ("RLE",)),
        "d": (False, "zstd", ("RLE", "BYTE_STREAM_SPLIT")),
        "e": (True, "brotli", ("PLAIN", "RLE", "RLE_DICTIONARY")),
    }

    session = SessionContext()
    frame = session.from_arrow(pa.record_batch(data))
    frame.write_parquet_with_options(
        tmp_path,
        ParquetWriterOptions(
            compression="brotli(8)", column_specific_options=per_column
        ),
    )

    for parquet_path in tmp_path.rglob("*.parquet"):
        row_groups = pq.ParquetFile(parquet_path).metadata.to_dict()["row_groups"]
        for group in row_groups:
            for chunk in group["columns"]:
                has_stats, codec, encodings = expected[chunk["path_in_schema"]]
                assert (chunk["statistics"] is not None) == has_stats
                assert chunk["compression"].lower() == codec.lower()
                assert chunk["encodings"] == encodings


def test_write_parquet_options(df, tmp_path):
    """``write_parquet`` accepts a ``ParquetWriterOptions`` object positionally."""
    writer_opts = ParquetWriterOptions(compression="gzip", compression_level=6)
    target = str(tmp_path)

    df.write_parquet(target, writer_opts)

    assert pq.read_table(target).to_pydict() == df.to_pydict()


def test_write_parquet_options_error(df, tmp_path):
    """Combining ``ParquetWriterOptions`` with ``compression_level`` raises ``ValueError``."""
    writer_opts = ParquetWriterOptions(compression="gzip", compression_level=6)
    target = str(tmp_path)
    with pytest.raises(ValueError):
        df.write_parquet(target, writer_opts, compression_level=1)


def test_write_table(ctx, df):
    """``write_table`` adds a DataFrame's rows to an existing registered table.

    Table ``t`` starts as ``[1, 2, 3]``. Writing its negated column back into
    ``t`` must leave both the original and the negated rows.
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3])],
        names=["a"],
    )

    ctx.register_record_batches("t", [[batch]])

    # Named ``negated`` rather than ``df`` so the fixture argument is not shadowed.
    negated = ctx.table("t").with_column("a", column("a") * literal(-1))

    negated.write_table("t")
    result = ctx.table("t").sort(column("a")).collect()[0][0].to_pylist()
    expected = [-3, -2, -1, 1, 2, 3]

    assert result == expected


def test_dataframe_export(df) -> None:
    """``pa.table(df)`` consumes the DataFrame export, optionally with a schema.

    A requested schema may project columns, produce all-null columns for
    names absent from the DataFrame, and must fail when a column cannot be
    converted or a missing column is declared non-nullable.
    """
    # Guarantees that we have the canonical implementation
    # reading our dataframe export
    table = pa.table(df)
    assert table.num_columns == 3
    assert table.num_rows == 3

    desired_schema = pa.schema([("a", pa.int64())])

    # Verify we can request a schema
    table = pa.table(df, schema=desired_schema)
    assert table.num_columns == 1
    assert table.num_rows == 3

    # Expect a table of nulls if the schemas don't overlap
    desired_schema = pa.schema([("g", pa.string())])
    table = pa.table(df, schema=desired_schema)
    assert table.num_columns == 1
    assert table.num_rows == 3
    assert table[0].to_pylist() == [None, None, None]

    # Expect an error when we cannot convert schema
    desired_schema = pa.schema([("a", pa.float32())])
    with pytest.raises(Exception):  # noqa: B017, PT011
        pa.table(df, schema=desired_schema)

    # Expect an error when a missing column is declared non-nullable
    desired_schema = pa.schema([("g", pa.string(), False)])
    with pytest.raises(Exception):  # noqa: B017, PT011
        pa.table(df, schema=desired_schema)


def test_dataframe_transform(df):
    """``DataFrame.transform`` applies a callable and forwards extra arguments."""

    def _with_string_column(frame) -> DataFrame:
        return frame.with_column("string_col", literal("string data"))

    def _with_constant(frame, value: Any) -> DataFrame:
        return frame.with_column("new_col", literal(value))

    transformed = df.transform(_with_string_column)
    transformed = transformed.transform(_with_constant, 3)
    columns = transformed.to_pydict()

    assert columns["a"] == [1, 2, 3]
    assert columns["string_col"] == ["string data"] * 3
    assert columns["new_col"] == [3] * 3


def test_dataframe_repr_html_structure(df, clean_formatter_state) -> None:
    """``DataFrame._repr_html_`` renders the expected headers and cells in order."""
    html = df._repr_html_()

    # The HTML carries extra formatting, so only the values are checked. Tag
    # attributes are skipped with non-greedy groups, and the closing ">" of
    # th/td is left out of the fixed text because styling is written there.
    header_pattern = "(.*?)".join(
        f"<th(.*?)>{name}</th>" for name in ["a", "b", "c"]
    )
    assert len(re.findall(header_pattern, html, re.DOTALL)) == 1

    # Cell values may optionally be wrapped in a <span>.
    rows = [[1, 4, 8], [2, 5, 5], [3, 6, 8]]
    body_pattern = "(.*?)".join(
        f"<td(.*?)>(?:<span[^>]*?>)?{value}(?:</span>)?</td>"
        for row in rows
        for value in row
    )
    body_matches = re.findall(body_pattern, html, re.DOTALL)

    assert len(body_matches) == 1, "Expected pattern of values not found in HTML output"


def test_dataframe_repr_html_values(df, clean_formatter_state):
    """``DataFrame._repr_html_`` contains the data values in row-major order."""
    html = df._repr_html_()
    assert html is not None

    # One <td> per value, optionally wrapped in a <span>. Any markup between
    # consecutive cells is skipped, so only the order 1,4,8,2,5,5,3,6,8 matters.
    cell = r"<td[^>]*?>(?:<span[^>]*?>)?{}(?:</span>)?</td>"
    ordered_values = (1, 4, 8, 2, 5, 5, 3, 6, 8)
    pattern = re.compile(
        r".*?".join(cell.format(value) for value in ordered_values),
        re.DOTALL,
    )

    matches = re.findall(pattern, html)
    if not matches:
        # Print debug info if the test fails
        print(f"HTML output snippet: {html[:500]}...")  # noqa: T201

    assert len(matches) > 0, "Expected pattern of values not found in HTML output"


def test_html_formatter_shared_styles(df, clean_formatter_state):
    """With shared styles on, every render embeds the df-styles script, never <style>."""
    configure_formatter(use_shared_styles=True)

    renders = [df._repr_html_(), df._repr_html_()]

    for html in renders:
        assert "<script>" in html
        assert "df-styles" in html
        assert "<style>" not in html


def test_html_formatter_no_shared_styles(df, clean_formatter_state):
    """With shared styles off, each render still embeds the df-styles script.

    No render may contain an inline ``<style>`` tag.
    """
    configure_formatter(use_shared_styles=False)

    renders = [df._repr_html_(), df._repr_html_()]

    for html in renders:
        assert "<script>" in html
        assert "df-styles" in html
        assert "<style>" not in html


def test_html_formatter_manual_format_html(clean_formatter_state):
    """``format_html`` works on both the global and a locally built formatter.

    Repeated calls on the global formatter and on a fresh
    ``DataFrameHtmlFormatter(use_shared_styles=False)`` must each embed the
    ``df-styles`` script and never an inline ``<style>`` tag.
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )

    def _render_twice(formatter):
        return [formatter.format_html([batch], batch.schema) for _ in range(2)]

    def _assert_script_styles(html):
        assert "<script>" in html
        assert "df-styles" in html
        assert "<style>" not in html

    for html in _render_twice(get_formatter()):
        _assert_script_styles(html)

    # A separate formatter with shared styles disabled behaves the same way.
    for html in _render_twice(DataFrameHtmlFormatter(use_shared_styles=False)):
        _assert_script_styles(html)


def test_fill_null_basic(null_df):
    """A single scalar passed to ``fill_null`` fills nulls across all columns shown."""
    batch = null_df.fill_null(0).collect()[0]

    expected_columns = [
        pa.array([1, 0, 3, 0]),
        pa.array([4.5, 6.7, 0.0, 0.0]),
        # String column should be filled with "0"
        pa.array(["a", "0", "c", "0"]),
        # Boolean column should be filled with False (0 converted to bool)
        pa.array([True, False, False, False]),
    ]
    for index, expected in enumerate(expected_columns):
        assert batch.column(index) == expected


def test_fill_null_subset(null_df):
    """With ``subset``, only the listed columns have their nulls replaced."""
    batch = null_df.fill_null(0, subset=["int_col", "float_col"]).collect()[0]

    assert batch.column(0) == pa.array([1, 0, 3, 0])
    assert batch.column(1) == pa.array([4.5, 6.7, 0.0, 0.0])

    # Columns outside the subset keep their nulls.
    for untouched in (2, 3):
        assert None in batch.column(untouched).to_pylist()


def test_fill_null_str_column(null_df):
    """String nulls can be filled with any text, including the empty string."""
    replaced = null_df.fill_null("N/A", subset=["str_col"]).collect()[0]

    assert replaced.column(2).to_pylist() == ["a", "N/A", "c", "N/A"]
    # Other columns should be unchanged
    for untouched in (0, 1, 3):
        assert None in replaced.column(untouched).to_pylist()

    emptied = null_df.fill_null("", subset=["str_col"]).collect()[0]
    assert emptied.column(2).to_pylist() == ["a", "", "c", ""]


def test_fill_null_bool_column(null_df):
    """Boolean nulls can be filled with either ``True`` or ``False``."""
    with_true = null_df.fill_null(value=True, subset=["bool_col"]).collect()[0]

    assert with_true.column(3).to_pylist() == [True, True, False, True]
    # Other columns should be unchanged
    assert None in with_true.column(0).to_pylist()

    with_false = null_df.fill_null(value=False, subset=["bool_col"]).collect()[0]
    assert with_false.column(3).to_pylist() == [True, False, False, False]


def test_fill_null_date32_column(null_df):
    """A ``datetime.date`` fills nulls in the date32 column only."""
    epoch = datetime.date(1970, 1, 1)
    batch = null_df.fill_null(epoch, subset=["date32_col"]).collect()[0]

    assert batch.column(4).to_pylist() == [
        datetime.date(2000, 1, 1),  # Original value
        epoch,  # Filled value
        datetime.date(2022, 1, 1),  # Original value
        epoch,  # Filled value
    ]

    # Other date column should be unchanged
    assert None in batch.column(5).to_pylist()


def test_fill_null_date64_column(null_df):
    """A ``datetime.date`` fills nulls in the date64 column only."""
    epoch = datetime.date(1970, 1, 1)
    batch = null_df.fill_null(epoch, subset=["date64_col"]).collect()[0]

    assert batch.column(5).to_pylist() == [
        datetime.date(2000, 1, 1),  # Original value
        epoch,  # Filled value
        datetime.date(2022, 1, 1),  # Original value
        epoch,  # Filled value
    ]

    # Other date column should be unchanged
    assert None in batch.column(4).to_pylist()


def test_fill_null_type_coercion(null_df):
    """Fill values of a different type are coerced to the target column type."""
    # A number used to fill a string column is stored as its string form.
    as_text = null_df.fill_null(42, subset=["str_col"]).collect()[0]
    assert as_text.column(2).to_pylist() == ["a", "42", "c", "42"]

    # A string used to fill a bool column: the exact conversion is
    # implementation-defined, so only check that no nulls remain.
    as_bool = null_df.fill_null("true", subset=["bool_col"]).collect()[0]
    assert None not in as_bool.column(3).to_pylist()


def test_fill_null_multiple_date_columns(null_df):
    """A single ``fill_null`` call can fill the date32 and date64 columns together."""
    fill_value = datetime.date(2023, 12, 31)
    date_columns = ["date32_col", "date64_col"]
    batch = null_df.fill_null(fill_value, subset=date_columns).collect()[0]

    # Columns 4 and 5 are date32_col and date64_col; rows 1 and 3 were null.
    for column_index in (4, 5):
        values = batch.column(column_index).to_pylist()
        assert None not in values
        assert values[1] == fill_value
        assert values[3] == fill_value


def test_fill_null_specific_types(null_df):
    """Test that a string fill value only fills the string-typed column.

    ``fill_null("missing")`` is applied without a ``subset``. The string column
    is filled, while the integer, float, boolean and date columns keep their
    nulls, as asserted below.
    """
    filled_df = null_df.fill_null("missing")

    result = filled_df.collect()[0]

    # Only the string column (index 2) changes; every other column keeps its nulls.
    assert result.column(0).to_pylist() == [1, None, 3, None]
    assert result.column(1).to_pylist() == [4.5, 6.7, None, None]
    assert result.column(2).to_pylist() == ["a", "missing", "c", "missing"]
    assert result.column(3).to_pylist() == [True, None, False, None]  # bool not filled
    assert result.column(4).to_pylist() == [
        datetime.date(2000, 1, 1),
        None,
        datetime.date(2022, 1, 1),
        None,
    ]
    assert result.column(5).to_pylist() == [
        datetime.date(2000, 1, 1),
        None,
        datetime.date(2022, 1, 1),
        None,
    ]


def test_fill_null_immutability(null_df):
    """Test that original DataFrame is unchanged after fill_null.

    The filled DataFrame is collected so the fill is actually executed rather
    than merely planned. Only then is the original re-collected and compared.
    """
    # Get original values with nulls
    original = null_df.collect()[0]
    original_int_nulls = original.column(0).to_pylist().count(None)
    assert original_int_nulls > 0  # Ensure we actually had nulls in the first place

    # Apply fill_null and execute it; the fill must take effect on the new frame,
    # otherwise the immutability check below proves nothing.
    filled = null_df.fill_null(0).collect()[0]
    assert None not in filled.column(0).to_pylist()

    # Check that original is unchanged
    new_original = null_df.collect()[0]
    new_original_int_nulls = new_original.column(0).to_pylist().count(None)

    assert original_int_nulls == new_original_int_nulls


def test_fill_null_empty_df(ctx):
    """``fill_null`` on a zero-row DataFrame succeeds and keeps the column names."""
    column_names = ["a", "b"]
    empty_batch = pa.RecordBatch.from_arrays(
        [pa.array([], type=pa.int64()), pa.array([], type=pa.string())],
        names=column_names,
    )
    filled = ctx.create_dataframe([[empty_batch]]).fill_null(0)

    out = filled.collect()[0]
    # Still no rows in either column ...
    assert len(out.column(0)) == 0
    assert len(out.column(1)) == 0
    # ... and the schema names survive in order.
    assert [out.schema.field(idx).name for idx in range(2)] == column_names


def test_fill_null_all_null_column(ctx):
    """A string column that contains only nulls is completely filled."""
    source = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([None, None, None], type=pa.string())],
        names=["a", "b"],
    )
    out = ctx.create_dataframe([[source]]).fill_null("filled").collect()[0]

    # Every one of the three nulls in column "b" is replaced.
    assert out.column(1).to_pylist() == ["filled"] * 3


def _build_long_running_df(ctx):
    """Register tables ``t1`` and ``t2`` on ``ctx`` and return a slow join query.

    Both tables hold the same ten 1000-row batches (column ``a``: consecutive
    integers, column ``b``: ``"value_<a>"`` strings). Each side is cross-joined
    with 5 rows and then joined on ``a % 100``. The intent is a result large
    enough that executing it takes noticeable time.
    """
    batches = [
        pa.RecordBatch.from_arrays(
            [
                pa.array(list(range(i * 1000, (i + 1) * 1000))),
                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
            ],
            names=["a", "b"],
        )
        for i in range(10)
    ]

    ctx.register_record_batches("t1", [batches])
    ctx.register_record_batches("t2", [batches])

    return ctx.sql(
        """
        WITH t1_expanded AS (
            SELECT
                a,
                b,
                CAST(a AS DOUBLE) / 1.5 AS c,
                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
            FROM t1
            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
        ),
        t2_expanded AS (
            SELECT
                a,
                b,
                CAST(a AS DOUBLE) * 2.5 AS e,
                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
            FROM t2
            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
        )
        SELECT
            t1.a, t1.b, t1.c, t1.d,
            t2.a AS a2, t2.b AS b2, t2.e, t2.f
        FROM t1_expanded t1
        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
        WHERE t1.a > 100 AND t2.a > 100
        """
    )


def _raise_keyboard_interrupt_in_main_thread() -> None:
    """Asynchronously schedule ``KeyboardInterrupt`` in the main thread.

    Raises:
        RuntimeError: if the main thread id is unavailable, or if CPython did
            not modify exactly one thread state.
    """
    thread_id = threading.main_thread().ident
    if thread_id is None:
        msg = "Cannot get main thread ID"
        raise RuntimeError(msg)

    # Since Python 3.7 the C signature takes an ``unsigned long`` thread id.
    tid = ctypes.c_ulong(thread_id)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        tid, ctypes.py_object(KeyboardInterrupt)
    )
    if res != 1:
        # 0: the thread id was invalid; >1: more than one thread was modified.
        # ctypes passes ``None`` as a NULL pointer, which revokes the pending
        # exception (``py_object(0)`` would instead schedule the int 0).
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        msg = "Failed to raise KeyboardInterrupt in main thread"
        raise RuntimeError(msg)


def _expect_keyboard_interrupt(consume, label, max_wait_time=5.0) -> None:
    """Run ``consume()`` in the main thread and interrupt it with Ctrl-C.

    A daemon helper thread waits until ``consume`` is about to be called and
    pauses briefly so the call can start. It then raises ``KeyboardInterrupt``
    in the main thread. The test fails unless ``consume()`` is ended by that
    ``KeyboardInterrupt``.

    Args:
        consume: Zero-argument callable that performs the long-running work.
        label: Subject used in the failure message (e.g. ``"Query"``).
        max_wait_time: Seconds the helper waits for ``consume`` to start.
    """
    query_started = threading.Event()
    # Exceptions in the helper thread would otherwise only reach threading's
    # excepthook; collect them so the failure message can show them.
    helper_errors = []

    def trigger_interrupt():
        try:
            if not query_started.wait(timeout=max_wait_time):
                msg = f"Query did not start within {max_wait_time} seconds"
                raise RuntimeError(msg)
            # Give consume() a moment to enter its blocking call first.
            time.sleep(0.1)
            _raise_keyboard_interrupt_in_main_thread()
        except Exception as e:
            helper_errors.append(e)

    # Daemon so the test process can exit even if this thread doesn't finish.
    interrupt_thread = threading.Thread(target=trigger_interrupt, daemon=True)
    interrupt_thread.start()

    interrupted = False
    consume_error = None
    try:
        # Signal that we're about to start the long-running call
        query_started.set()
        consume()
    except KeyboardInterrupt:
        interrupted = True
    except Exception as e:
        consume_error = e

    if not interrupted:
        # consume() returned (or failed) before the interrupt landed. Wait for
        # the helper so its errors are visible, and absorb an interrupt that
        # arrives late so it is reported as a test failure instead of escaping.
        try:
            interrupt_thread.join(timeout=1.0)
        except KeyboardInterrupt:
            pass
        pytest.fail(
            f"{label} was not interrupted; got error: {consume_error}; "
            f"interrupt thread errors: {helper_errors}"
        )

    # Make sure the interrupt thread has finished
    interrupt_thread.join(timeout=1.0)


def test_collect_interrupted():
    """Test that a long-running query can be interrupted with Ctrl-C.

    This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
    exception in the main thread during a long-running query execution.
    """
    ctx = SessionContext()
    df = _build_long_running_df(ctx)
    _expect_keyboard_interrupt(df.collect, "Query")


def test_arrow_c_stream_interrupted():
    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.

    Similar to ``test_collect_interrupted`` this test issues a long running
    query, but consumes the results via ``__arrow_c_stream__``. It then raises
    ``KeyboardInterrupt`` in the main thread and verifies that the stream
    iteration stops promptly with the appropriate exception.
    """
    ctx = SessionContext()
    df = _build_long_running_df(ctx)
    reader = pa.RecordBatchReader.from_stream(df)
    # Consuming the reader should block and then be interrupted.
    _expect_keyboard_interrupt(reader.read_all, "Stream")


def test_show_select_where_no_rows(capsys) -> None:
    """Showing a query with an always-false predicate prints a no-rows notice."""
    session = SessionContext()
    session.sql("SELECT 1 WHERE 1=0").show()
    captured = capsys.readouterr()
    assert "DataFrame has no rows" in captured.out


def test_show_from_empty_batch(capsys) -> None:
    """Showing a zero-row DataFrame still renders the column header."""
    session = SessionContext()
    empty = pa.record_batch([pa.array([], type=pa.int32())], names=["a"])
    session.create_dataframe([[empty]]).show()
    captured = capsys.readouterr()
    assert "| a |" in captured.out


@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]])
def test_register_parquet_file_sort_order(ctx, tmp_path, file_sort_order):
    """``register_parquet`` accepts ``file_sort_order`` as column names or exprs."""
    table = pa.table({"a": [1, 2]})
    path = tmp_path / "file.parquet"
    pq.write_table(table, path)
    ctx.register_parquet("t", path, file_sort_order=file_sort_order)
    assert "t" in ctx.catalog().schema().names()
    # The registered table must also be queryable, not merely listed.
    assert ctx.table("t").collect()[0].column(0).to_pylist() == [1, 2]


@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]])
def test_register_listing_table_file_sort_order(ctx, tmp_path, file_sort_order):
    """``register_listing_table`` accepts ``file_sort_order`` as names or exprs."""
    table = pa.table({"a": [1, 2]})
    dir_path = tmp_path / "dir"
    dir_path.mkdir()
    pq.write_table(table, dir_path / "file.parquet")
    ctx.register_listing_table(
        "t", dir_path, schema=table.schema, file_sort_order=file_sort_order
    )
    assert "t" in ctx.catalog().schema().names()
    # The registered table must also be queryable, not merely listed.
    assert ctx.table("t").collect()[0].column(0).to_pylist() == [1, 2]


@pytest.mark.parametrize("file_sort_order", [[["a"]], [[df_col("a")]]])
def test_read_parquet_file_sort_order(tmp_path, file_sort_order):
    """``read_parquet`` accepts ``file_sort_order`` and returns the file's rows."""
    ctx = SessionContext()
    table = pa.table({"a": [1, 2]})
    path = tmp_path / "data.parquet"
    # Use the module's ``pq`` alias rather than the implicit ``pa.parquet`` attribute.
    pq.write_table(table, path)
    df = ctx.read_parquet(path, file_sort_order=file_sort_order)
    assert df.collect()[0].column(0).to_pylist() == [1, 2]
