# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime as dt
import gzip
import pathlib

import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import (
    DataFrame,
    RuntimeEnvBuilder,
    SessionConfig,
    SessionContext,
    SQLOptions,
    Table,
    column,
    literal,
)


def test_create_context_no_args():
    """A SessionContext can be constructed without any arguments."""
    _ = SessionContext()


def test_create_context_session_config_only():
    """A SessionContext can be constructed from a SessionConfig alone."""
    session_config = SessionConfig()
    _ = SessionContext(config=session_config)


def test_create_context_runtime_config_only():
    """A SessionContext can be constructed from a RuntimeEnvBuilder alone."""
    runtime_builder = RuntimeEnvBuilder()
    _ = SessionContext(runtime=runtime_builder)


@pytest.mark.parametrize("path_to_str", [True, False])
def test_runtime_configs(tmp_path, path_to_str):
    """Create a context using ``with_disk_manager_specified`` with two paths.

    The two directories are passed as ``str`` when ``path_to_str`` is true
    and as path objects otherwise.
    """
    directories = [tmp_path / "dir1", tmp_path / "dir2"]
    if path_to_str:
        directories = [str(directory) for directory in directories]

    runtime = RuntimeEnvBuilder().with_disk_manager_specified(*directories)
    config = SessionConfig().with_default_catalog_and_schema("foo", "bar")
    ctx = SessionContext(config, runtime)
    assert ctx is not None

    # The configured default catalog/schema pair must be resolvable.
    schema = ctx.catalog("foo").schema("bar")
    assert schema is not None


@pytest.mark.parametrize("path_to_str", [True, False])
def test_temporary_files(tmp_path, path_to_str):
    """Create a context using ``with_temp_file_path``.

    The path is passed as ``str`` when ``path_to_str`` is true and as a
    path object otherwise.
    """
    if path_to_str:
        temp_location = str(tmp_path)
    else:
        temp_location = tmp_path

    runtime = RuntimeEnvBuilder().with_temp_file_path(temp_location)
    config = SessionConfig().with_default_catalog_and_schema("foo", "bar")
    ctx = SessionContext(config, runtime)
    assert ctx is not None

    # The configured default catalog/schema pair must be resolvable.
    schema = ctx.catalog("foo").schema("bar")
    assert schema is not None


def test_create_context_with_all_valid_args():
    """Create a context from a fully customised config and runtime."""
    runtime = RuntimeEnvBuilder()
    runtime = runtime.with_disk_manager_os()
    runtime = runtime.with_fair_spill_pool(10000000)

    config = SessionConfig()
    config = config.with_create_default_catalog_and_schema(enabled=True)
    config = config.with_default_catalog_and_schema("foo", "bar")
    config = config.with_target_partitions(1)
    config = config.with_information_schema(enabled=True)
    config = config.with_repartition_joins(enabled=False)
    config = config.with_repartition_aggregations(enabled=False)
    config = config.with_repartition_windows(enabled=False)
    config = config.with_parquet_pruning(enabled=False)

    ctx = SessionContext(config, runtime)

    # Spot-check the configuration: the custom catalog/schema resolves,
    # while looking up a catalog named "datafusion" raises KeyError.
    ctx.catalog("foo").schema("bar")
    with pytest.raises(KeyError):
        ctx.catalog("datafusion")


def test_register_record_batches(ctx):
    """Register a record batch as table ``t`` and query it through SQL."""
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])

    # Nested list: one inner list containing the single batch.
    ctx.register_record_batches("t", [[record_batch]])

    assert ctx.catalog().schema().names() == {"t"}

    batches = ctx.sql("SELECT a+b, a-b FROM t").collect()
    first = batches[0]

    assert first.column(0) == pa.array([5, 7, 9])
    assert first.column(1) == pa.array([-3, -3, -3])


def test_create_dataframe_registers_unique_table_name(ctx):
    """``create_dataframe`` without a name registers a generated table name.

    The generated name is expected to be 33 characters long: the letter
    ``c`` followed by 32 lowercase hexadecimal digits.
    """
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])

    df = ctx.create_dataframe([[record_batch]])
    registered = list(ctx.catalog().schema().names())

    assert df
    assert len(registered) == 1

    generated_name = registered[0]
    assert len(generated_name) == 33
    assert generated_name.startswith("c")
    # Everything after the leading "c" must be a hexadecimal digit.
    assert all(ch in "0123456789abcdef" for ch in generated_name[1:])


def test_create_dataframe_registers_with_defined_table_name(ctx):
    """``create_dataframe`` with ``name`` registers the table under it."""
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])

    df = ctx.create_dataframe([[record_batch]], name="tbl")
    registered = list(ctx.catalog().schema().names())

    assert df
    assert len(registered) == 1
    assert registered[0] == "tbl"


def test_from_arrow_table(ctx):
    """Convert a PyArrow table into a DataFrame and check its shape."""
    arrow_table = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})

    frame = ctx.from_arrow(arrow_table)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    first_batch = frame.collect()[0]
    assert first_batch.num_rows == 3


def record_batch_generator(num_batches: int):
    """Yield ``num_batches`` identical record batches with int64 columns a, b."""
    batch_schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
    for _ in range(num_batches):
        # Fresh arrays are built for every batch that is yielded.
        columns = [pa.array([1, 2, 3]), pa.array([4, 5, 6])]
        yield pa.RecordBatch.from_arrays(columns, schema=batch_schema)


@pytest.mark.parametrize(
    "source",
    [
        # __arrow_c_array__ sources
        pa.array([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]),
        # __arrow_c_stream__ sources
        pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pa.RecordBatchReader.from_batches(
            pa.schema([("a", pa.int64()), ("b", pa.int64())]), record_batch_generator(1)
        ),
        pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
    ],
)
def test_from_arrow_sources(ctx, source) -> None:
    """``from_arrow`` accepts each parametrized Arrow source object.

    Every source holds the same three rows with columns ``a`` and ``b``.
    """
    frame = ctx.from_arrow(source)

    assert frame
    assert isinstance(frame, DataFrame)

    column_names = frame.schema().names
    assert column_names == ["a", "b"]

    row_count = frame.count()
    assert row_count == 3


def test_from_arrow_table_with_name(ctx):
    """``from_arrow`` registers the table under the supplied ``name``."""
    arrow_table = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})

    frame = ctx.from_arrow(arrow_table, name="tbl")
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert registered[0] == "tbl"


def test_from_arrow_table_empty(ctx):
    """An empty PyArrow table with an explicit schema yields no batches."""
    explicit_schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
    arrow_table = pa.Table.from_pydict({"a": [], "b": []}, schema=explicit_schema)

    frame = ctx.from_arrow(arrow_table)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    collected = frame.collect()
    assert len(collected) == 0


def test_from_arrow_table_empty_no_schema(ctx):
    """An empty PyArrow table built without a schema yields no batches."""
    arrow_table = pa.Table.from_pydict({"a": [], "b": []})

    frame = ctx.from_arrow(arrow_table)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    collected = frame.collect()
    assert len(collected) == 0


def test_from_pylist(ctx):
    """Build a DataFrame from a list of row dictionaries."""
    rows = [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]

    frame = ctx.from_pylist(rows)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    first_batch = frame.collect()[0]
    assert first_batch.num_rows == 3


def test_from_pydict(ctx):
    """Build a DataFrame from a dictionary of column lists."""
    columns = {"a": [1, 2, 3], "b": [4, 5, 6]}

    frame = ctx.from_pydict(columns)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    first_batch = frame.collect()[0]
    assert first_batch.num_rows == 3


def test_from_pandas(ctx):
    """Build a DataFrame from a pandas DataFrame; skipped without pandas."""
    pandas = pytest.importorskip("pandas")
    source_frame = pandas.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    frame = ctx.from_pandas(source_frame)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    first_batch = frame.collect()[0]
    assert first_batch.num_rows == 3


def test_from_polars(ctx):
    """Build a DataFrame from a Polars DataFrame; skipped without polars."""
    polars = pytest.importorskip("polars")
    source_frame = polars.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    frame = ctx.from_polars(source_frame)
    registered = list(ctx.catalog().schema().names())

    assert frame
    assert len(registered) == 1
    assert isinstance(frame, DataFrame)
    assert set(frame.schema().names) == {"a", "b"}

    first_batch = frame.collect()[0]
    assert first_batch.num_rows == 3


def test_register_table(ctx, database):
    """Register an existing catalog table under the extra name ``csv3``.

    The ``database`` fixture is defined outside this file; presumably it
    is what populates the ``public`` schema with the three csv tables.
    """
    public_schema = ctx.catalog().schema("public")
    assert public_schema.names() == {"csv", "csv1", "csv2"}

    existing_table = public_schema.table("csv")
    ctx.register_table("csv3", existing_table)

    assert public_schema.names() == {"csv", "csv1", "csv2", "csv3"}


def test_read_table_from_catalog(ctx, database):
    """``read_table`` accepts a table obtained from the catalog.

    The ``database`` fixture is defined outside this file; presumably it
    is what populates the ``public`` schema with the three csv tables.
    """
    public_schema = ctx.catalog().schema("public")
    assert public_schema.names() == {"csv", "csv1", "csv2"}

    catalog_table = public_schema.table("csv")
    frame = ctx.read_table(catalog_table)
    frame.show()


def test_read_table_from_df(ctx):
    """``read_table`` accepts a DataFrame and returns its rows unchanged."""
    source_frame = ctx.from_pydict({"a": [1, 2]})

    batches = ctx.read_table(source_frame).collect()
    as_dicts = [record_batch.to_pydict() for record_batch in batches]

    assert as_dicts == [{"a": [1, 2]}]


def test_read_table_from_dataset(ctx):
    """``read_table`` accepts a ``pyarrow.dataset`` dataset."""
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    arrow_dataset = ds.dataset([record_batch])

    batches = ctx.read_table(arrow_dataset).collect()
    first = batches[0]

    assert first.column(0) == pa.array([1, 2, 3])
    assert first.column(1) == pa.array([4, 5, 6])


def test_deregister_table(ctx, database):
    """Deregistering ``csv`` removes it from the ``public`` schema.

    The ``database`` fixture is defined outside this file; presumably it
    is what populates the ``public`` schema with the three csv tables.
    """
    public_schema = ctx.catalog().schema("public")
    assert public_schema.names() == {"csv", "csv1", "csv2"}

    ctx.deregister_table("csv")

    assert public_schema.names() == {"csv1", "csv2"}


def test_register_table_from_dataframe(ctx):
    """A DataFrame can be registered directly as a table and queried."""
    source_frame = ctx.from_pydict({"a": [1, 2]})
    ctx.register_table("df_tbl", source_frame)

    batches = ctx.sql("SELECT * FROM df_tbl").collect()
    as_dicts = [record_batch.to_pydict() for record_batch in batches]

    assert as_dicts == [{"a": [1, 2]}]


@pytest.mark.parametrize("temporary", [True, False])
def test_register_table_from_dataframe_into_view(ctx, temporary):
    """``into_view`` returns a Table whose ``kind`` follows ``temporary``."""
    source_frame = ctx.from_pydict({"a": [1, 2]})
    view = source_frame.into_view(temporary=temporary)
    assert isinstance(view, Table)

    expected_kind = "temporary" if temporary else "view"
    assert view.kind == expected_kind

    # The view must be usable as a regular registered table.
    ctx.register_table("view_tbl", view)
    batches = ctx.sql("SELECT * FROM view_tbl").collect()
    as_dicts = [record_batch.to_pydict() for record_batch in batches]
    assert as_dicts == [{"a": [1, 2]}]


def test_table_from_dataframe(ctx):
    """A ``Table`` built from a DataFrame can be registered and queried."""
    source_frame = ctx.from_pydict({"a": [1, 2]})
    wrapped = Table(source_frame)
    assert isinstance(wrapped, Table)

    ctx.register_table("from_dataframe_tbl", wrapped)
    batches = ctx.sql("SELECT * FROM from_dataframe_tbl").collect()
    as_dicts = [record_batch.to_pydict() for record_batch in batches]

    assert as_dicts == [{"a": [1, 2]}]


def test_table_from_dataframe_internal(ctx):
    """A ``Table`` built from a DataFrame's ``.df`` attribute is queryable."""
    source_frame = ctx.from_pydict({"a": [1, 2]})
    # Pass the object held in the wrapper's ``df`` attribute, not the wrapper.
    wrapped = Table(source_frame.df)
    assert isinstance(wrapped, Table)

    ctx.register_table("from_internal_dataframe_tbl", wrapped)
    batches = ctx.sql("SELECT * FROM from_internal_dataframe_tbl").collect()
    as_dicts = [record_batch.to_pydict() for record_batch in batches]

    assert as_dicts == [{"a": [1, 2]}]


def test_register_dataset(ctx):
    """Register a ``pyarrow.dataset`` dataset as ``t`` and query it via SQL."""
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    arrow_dataset = ds.dataset([record_batch])

    ctx.register_dataset("t", arrow_dataset)

    assert ctx.catalog().schema().names() == {"t"}

    batches = ctx.sql("SELECT a+b, a-b FROM t").collect()
    first = batches[0]

    assert first.column(0) == pa.array([5, 7, 9])
    assert first.column(1) == pa.array([-3, -3, -3])


def test_dataset_filter(ctx, capfd):
    """A SQL WHERE clause over a dataset appears as ``filter_expr`` in the plan."""
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    arrow_dataset = ds.dataset([record_batch])
    ctx.register_dataset("t", arrow_dataset)

    assert ctx.catalog().schema().names() == {"t"}
    query = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5")

    # explain() prints the plan; the captured output must show the
    # pushed-down filter expression.
    query.explain()
    plan_output = capfd.readouterr().out
    assert "filter_expr=(((a >= 2) and (a <= 3)) and (b > 5))" in plan_output

    batches = query.collect()
    first = batches[0]

    assert first.column(0) == pa.array([9])
    assert first.column(1) == pa.array([-3])


def test_dataset_count(ctx):
    """Row counts over a registered dataset agree in both APIs."""
    # `datafusion-python` issue: https://github.com/apache/datafusion-python/issues/800
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    arrow_dataset = ds.dataset([record_batch])
    ctx.register_dataset("t", arrow_dataset)

    # DataFrame API
    frame = ctx.table("t")
    assert frame.count() == 3

    # SQL API
    count_batches = ctx.sql("SELECT COUNT(*) FROM t").collect()
    assert count_batches[0].column(0) == pa.array([3])


def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
    """Ensure that pyarrow filter gets pushed down for `IsNull`"""
    # Column "c" holds a single null, in the row where "a" is 2.
    columns = [
        pa.array([1, 2, 3]),
        pa.array([4, 5, 6]),
        pa.array([7, None, 9]),
    ]
    record_batch = pa.RecordBatch.from_arrays(columns, names=["a", "b", "c"])
    arrow_dataset = ds.dataset([record_batch])
    ctx.register_dataset("t", arrow_dataset)

    query = ctx.sql("SELECT a FROM t WHERE c is NULL")

    # explain() prints the plan; the captured output must show the
    # pushed-down is_null filter.
    query.explain()
    plan_output = capfd.readouterr().out
    assert "filter_expr=is_null(c, {nan_is_null=false})" in plan_output

    batches = query.collect()
    assert batches[0].column(0) == pa.array([2])


def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
    """Ensure that pyarrow filter gets pushed down for timestamp"""
    # Ref: https://github.com/apache/datafusion-python/issues/703

    ts_type = pa.timestamp("ns", "+00:00")
    cutoff = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), ts_type)

    # Build a dataset description without writing any file to disk.
    filesystem = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
    parquet_format = ds.ParquetFileFormat()
    partition_filter = ds.field("a") <= cutoff

    # NOTE: "1.parquet" is never created. With working predicate pushdown
    # the fragment is excluded by its partition expression and never read.
    phantom_fragment = parquet_format.make_fragment(
        "1.parquet",
        filesystem=filesystem,
        partition_expression=partition_filter,
    )
    arrow_dataset = ds.FileSystemDataset(
        [phantom_fragment],
        pa.schema([pa.field("a", ts_type)]),
        parquet_format,
        filesystem,
    )

    ctx.register_dataset("t", arrow_dataset)

    # The only fragment is declared to hold a <= 2000-01-01, so a query
    # for a > 2024-01-01 should not touch any files.
    query = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
    assert query.collect() == []


def test_dataset_filter_nested_data(ctx):
    """Filter and project on fields of a struct column in a dataset."""
    # A single struct column "nested_data" with int fields "a" and "b".
    struct_column = pa.StructArray.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    record_batch = pa.RecordBatch.from_arrays([struct_column], names=["nested_data"])
    ctx.register_dataset("t", ds.dataset([record_batch]))

    assert ctx.catalog().schema().names() == {"t"}

    field_a = column("nested_data")["a"]
    field_b = column("nested_data")["b"]

    # This filter will not be pushed down to DatasetExec since it
    # isn't supported
    filtered = ctx.table("t").filter(field_b > literal(5))
    projected = filtered.select(field_a + field_b, field_a - field_b)

    batches = projected.collect()
    first = batches[0]

    assert first.column(0) == pa.array([9])
    assert first.column(1) == pa.array([-3])


def test_table_exist(ctx):
    """``table_exist`` returns ``True`` for a registered dataset."""
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    ctx.register_dataset("t", ds.dataset([record_batch]))

    exists = ctx.table_exist("t")
    assert exists is True


def test_table_not_found(ctx):
    """Looking up an unregistered table name raises ``KeyError``."""
    import uuid

    # A random suffix keeps the name from clashing with a registered table.
    missing_name = f"not-found-{uuid.uuid4()}"
    with pytest.raises(KeyError):
        ctx.table(missing_name)


def test_read_json(ctx):
    """Read a JSON file with defaults, with a schema, and with an extension."""
    data_dir = pathlib.Path(__file__).parent.resolve() / "data_test_context"
    json_path = data_dir / "data.json"

    # Default options
    batches = ctx.read_json(json_path).collect()
    first = batches[0]
    assert first.column(0) == pa.array(["a", "b", "c"])
    assert first.column(1) == pa.array([1, 2, 3])

    # Explicit schema restricted to the single string column "A"
    only_a = pa.schema([pa.field("A", pa.string(), nullable=True)])
    batches = ctx.read_json(json_path, schema=only_a).collect()
    first = batches[0]
    assert first.column(0) == pa.array(["a", "b", "c"])
    assert first.schema == only_a

    # Explicit file extension
    batches = ctx.read_json(json_path, file_extension=".json").collect()
    first = batches[0]
    assert first.column(0) == pa.array(["a", "b", "c"])
    assert first.column(1) == pa.array([1, 2, 3])


def test_read_json_compressed(ctx, tmp_path):
    """Read a gzip-compressed copy of the JSON test data."""
    data_dir = pathlib.Path(__file__).parent.resolve() / "data_test_context"
    json_path = data_dir / "data.json"
    gzip_path = tmp_path / "data.json.gz"

    # Write a gzip-compressed copy of the JSON file into tmp_path.
    with (
        json_path.open("rb") as plain_file,
        gzip.open(gzip_path, "wb") as compressed_file,
    ):
        compressed_file.writelines(plain_file)

    frame = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz")
    first = frame.collect()[0]

    assert first.column(0) == pa.array(["a", "b", "c"])
    assert first.column(1) == pa.array([1, 2, 3])


def test_read_csv(ctx):
    """Read a CSV file by relative path and display column ``c1``."""
    frame = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv")
    c1_only = frame.select(column("c1"))
    c1_only.show()


def test_read_csv_list(ctx):
    """Reading the same CSV path twice in one list doubles the row count."""
    csv_path = "testing/data/csv/aggregate_test_100.csv"

    single_frame = ctx.read_csv(path=[csv_path])
    expected = single_frame.count() * 2

    doubled_frame = ctx.read_csv(path=[csv_path, csv_path])
    actual = doubled_frame.count()

    doubled_frame.select(column("c1")).show()
    assert actual == expected


def test_read_csv_compressed(ctx, tmp_path):
    """Read a gzip-compressed copy of the CSV test data."""
    csv_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv")
    gzip_path = tmp_path / "aggregate_test_100.csv.gz"

    # Write a gzip-compressed copy of the CSV file into tmp_path.
    with (
        csv_path.open("rb") as plain_file,
        gzip.open(gzip_path, "wb") as compressed_file,
    ):
        compressed_file.writelines(plain_file)

    frame = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz")
    frame.select(column("c1")).show()


def test_read_parquet(ctx):
    """Read a Parquet file from a ``str`` path and from a ``Path`` object."""
    relative_path = "parquet/data/alltypes_plain.parquet"

    # Relative string path
    from_str = ctx.read_parquet(path=relative_path)
    from_str.show()
    assert from_str is not None

    # Absolute pathlib.Path built from the current working directory
    absolute_path = pathlib.Path.cwd() / relative_path
    from_path = ctx.read_parquet(path=absolute_path)
    assert from_path is not None


def test_read_avro(ctx):
    """Read an Avro file from a ``str`` path and from a ``Path`` object."""
    relative_path = "testing/data/avro/alltypes_plain.avro"

    # Relative string path
    from_str = ctx.read_avro(path=relative_path)
    from_str.show()
    assert from_str is not None

    # Absolute pathlib.Path built from the current working directory
    absolute_path = pathlib.Path.cwd() / relative_path
    from_path = ctx.read_avro(path=absolute_path)
    assert from_path is not None


def test_create_sql_options():
    """SQLOptions can be constructed without any arguments."""
    _ = SQLOptions()


def test_sql_with_options_no_ddl(ctx):
    """``sql_with_options`` rejects a DDL statement when DDL is disallowed."""
    ddl_statement = (
        "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')"
    )

    # The plain sql() call is expected to accept the statement.
    ctx.sql(ddl_statement)

    no_ddl = SQLOptions().with_allow_ddl(allow=False)
    with pytest.raises(Exception, match="DDL"):
        ctx.sql_with_options(ddl_statement, options=no_ddl)


def test_sql_with_options_no_dml(ctx):
    """``sql_with_options`` rejects a DML statement when DML is disallowed."""
    table_name = "t"
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 5, 6])
    record_batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    ctx.register_dataset(table_name, ds.dataset([record_batch]))

    insert_statement = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);'

    # The plain sql() call is expected to accept the statement.
    ctx.sql(insert_statement)

    no_dml = SQLOptions().with_allow_dml(allow=False)
    with pytest.raises(Exception, match="DML"):
        ctx.sql_with_options(insert_statement, options=no_dml)


def test_sql_with_options_no_statements(ctx):
    """``sql_with_options`` rejects ``SET`` when statements are disallowed."""
    set_statement = "SET time zone = 1;"

    # The plain sql() call is expected to accept the statement.
    ctx.sql(set_statement)

    no_statements = SQLOptions().with_allow_statements(allow=False)
    with pytest.raises(Exception, match="SetVariable"):
        ctx.sql_with_options(set_statement, options=no_statements)


@pytest.fixture
def batch():
    """Provide a record batch with one column ``a`` holding 4, 5, 6."""
    values = pa.array([4, 5, 6])
    return pa.RecordBatch.from_arrays([values], names=["a"])


def test_create_dataframe_with_global_ctx(batch):
    """The context from ``global_ctx()`` can create and collect a DataFrame."""
    global_context = SessionContext.global_ctx()

    frame = global_context.create_dataframe([[batch]])
    first_batch = frame.collect()[0]

    assert first_batch.column(0) == pa.array([4, 5, 6])
