# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import math
from datetime import datetime, timezone

import numpy as np
import pyarrow as pa
import pytest
from datafusion import SessionContext, column, literal, string_literal
from datafusion import functions as f

# Silence NumPy "invalid value" floating-point warnings: the NumPy reference
# computations in this module are evaluated on out-of-domain inputs as well
# (e.g. np.sqrt / np.arccosh of -0.7 in test_math_functions) and yield NaN.
np.seterr(invalid="ignore")

# Time zone attached to the datetimes built for fixtures and expected values.
DEFAULT_TZ = timezone.utc


@pytest.fixture
def df():
    """Provide a three-row DataFrame with columns ``a`` through ``e``.

    ``a`` and ``c`` are string-view text columns, ``b`` holds integers, ``d``
    holds timezone-aware datetimes and ``e`` holds booleans.
    """
    text = pa.array(["Hello", "World", "!"], type=pa.string_view())
    numbers = pa.array([4, 5, 6])
    padded_text = pa.array(["hello ", " world ", " !"], type=pa.string_view())
    timestamps = pa.array(
        [
            datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
            datetime(2027, 6, 26, tzinfo=DEFAULT_TZ),
            datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
        ]
    )
    flags = pa.array([False, True, True])

    batch = pa.RecordBatch.from_arrays(
        [text, numbers, padded_text, timestamps, flags],
        names=["a", "b", "c", "d", "e"],
    )
    # One partition holding one record batch.
    return SessionContext().create_dataframe([[batch]])


def test_named_struct(df):
    """``named_struct`` packs columns a, b and c into a struct stored as ``d``."""
    struct_fields = [(name, column(name)) for name in ("a", "b", "c")]
    with_struct = df.with_column("d", f.named_struct(struct_fields))

    expected = """DataFrame()
+-------+---+---------+------------------------------+-------+
| a     | b | c       | d                            | e     |
+-------+---+---------+------------------------------+-------+
| Hello | 4 | hello   | {a: Hello, b: 4, c: hello }  | false |
| World | 5 |  world  | {a: World, b: 5, c:  world } | true  |
| !     | 6 |  !      | {a: !, b: 6, c:  !}          | true  |
+-------+---+---------+------------------------------+-------+
""".strip()
    assert str(with_struct) == expected


def test_literal(df):
    """Each Python literal becomes a constant column with one value per row."""
    selected = df.select(
        literal(1),
        literal("1"),
        literal("OK"),
        literal(3.14),
        literal(value=True),
        literal(b"hello world"),
    )
    batches = selected.collect()
    assert len(batches) == 1
    batch = batches[0]

    # Expected columns, listed in the same order as the literals selected above.
    expected_columns = [
        pa.array([1] * 3),
        pa.array(["1"] * 3, type=pa.string_view()),
        pa.array(["OK"] * 3, type=pa.string_view()),
        pa.array([3.14] * 3),
        pa.array([True] * 3),
        pa.array([b"hello world"] * 3),
    ]
    for index, expected in enumerate(expected_columns):
        assert batch.column(index) == expected


def test_lit_arith(df):
    """Literals combine with columns in arithmetic and string concatenation."""
    incremented = literal(1) + column("b")
    exclaimed = f.concat(column("a").cast(pa.string()), literal("!"))

    batches = df.select(incremented, exclaimed).collect()
    assert len(batches) == 1
    batch = batches[0]

    expected_sum = pa.array([5, 6, 7])
    expected_text = pa.array(["Hello!", "World!", "!!"], type=pa.string_view())
    assert batch.column(0) == expected_sum
    assert batch.column(1) == expected_text


def test_math_functions():
    """Compare DataFusion math functions against their NumPy/``math`` equivalents.

    A two-column DataFrame (``value`` and ``na_value``) is projected through
    every math expression in one ``select``; each output column is then checked
    by position against a reference computed from the same input values.
    """
    ctx = SessionContext()
    # create a RecordBatch and a new DataFrame from it
    batch = pa.RecordBatch.from_arrays(
        [pa.array([0.1, -0.7, 0.55]), pa.array([float("nan"), 0, 2.0])],
        names=["value", "na_value"],
    )
    df = ctx.create_dataframe([[batch]])

    # NumPy copies of the two input columns, used to compute reference results.
    values = np.array([0.1, -0.7, 0.55])
    na_values = np.array([np.nan, 0, 2.0])
    col_v = column("value")
    col_nav = column("na_value")
    # The position of each expression below is the column index asserted on
    # further down; keep both lists in the same order.
    df = df.select(
        f.abs(col_v),
        f.sin(col_v),
        f.cos(col_v),
        f.tan(col_v),
        f.asin(col_v),
        f.acos(col_v),
        f.exp(col_v),
        f.ln(col_v + literal(pa.scalar(1))),
        f.log2(col_v + literal(pa.scalar(1))),
        f.log10(col_v + literal(pa.scalar(1))),
        f.random(),
        f.atan(col_v),
        f.atan2(col_v, literal(pa.scalar(1.1))),
        f.ceil(col_v),
        f.floor(col_v),
        f.power(col_v, literal(pa.scalar(3))),
        f.pow(col_v, literal(pa.scalar(4))),
        f.round(col_v),
        f.round(col_v, literal(pa.scalar(3))),
        f.sqrt(col_v),
        f.signum(col_v),
        f.trunc(col_v),
        f.asinh(col_v),
        f.acosh(col_v),
        f.atanh(col_v),
        f.cbrt(col_v),
        f.cosh(col_v),
        f.degrees(col_v),
        f.gcd(literal(9), literal(3)),
        f.lcm(literal(6), literal(4)),
        f.nanvl(col_nav, literal(5)),
        f.pi(),
        f.radians(col_v),
        f.sinh(col_v),
        f.tanh(col_v),
        f.factorial(literal(6)),
        f.isnan(col_nav),
        f.iszero(col_nav),
        f.log(literal(3), col_v + literal(pa.scalar(1))),
    )
    batches = df.collect()
    assert len(batches) == 1
    result = batches[0]

    np.testing.assert_array_almost_equal(result.column(0), np.abs(values))
    np.testing.assert_array_almost_equal(result.column(1), np.sin(values))
    np.testing.assert_array_almost_equal(result.column(2), np.cos(values))
    np.testing.assert_array_almost_equal(result.column(3), np.tan(values))
    np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values))
    np.testing.assert_array_almost_equal(result.column(5), np.arccos(values))
    np.testing.assert_array_almost_equal(result.column(6), np.exp(values))
    np.testing.assert_array_almost_equal(result.column(7), np.log(values + 1.0))
    np.testing.assert_array_almost_equal(result.column(8), np.log2(values + 1.0))
    np.testing.assert_array_almost_equal(result.column(9), np.log10(values + 1.0))
    # random() has no fixed expected value; only check every row is below 1.
    np.testing.assert_array_less(result.column(10), np.ones_like(values))
    np.testing.assert_array_almost_equal(result.column(11), np.arctan(values))
    np.testing.assert_array_almost_equal(result.column(12), np.arctan2(values, 1.1))
    np.testing.assert_array_almost_equal(result.column(13), np.ceil(values))
    np.testing.assert_array_almost_equal(result.column(14), np.floor(values))
    np.testing.assert_array_almost_equal(result.column(15), np.power(values, 3))
    np.testing.assert_array_almost_equal(result.column(16), np.power(values, 4))
    np.testing.assert_array_almost_equal(result.column(17), np.round(values))
    np.testing.assert_array_almost_equal(result.column(18), np.round(values, 3))
    np.testing.assert_array_almost_equal(result.column(19), np.sqrt(values))
    np.testing.assert_array_almost_equal(result.column(20), np.sign(values))
    np.testing.assert_array_almost_equal(result.column(21), np.trunc(values))
    np.testing.assert_array_almost_equal(result.column(22), np.arcsinh(values))
    np.testing.assert_array_almost_equal(result.column(23), np.arccosh(values))
    np.testing.assert_array_almost_equal(result.column(24), np.arctanh(values))
    np.testing.assert_array_almost_equal(result.column(25), np.cbrt(values))
    np.testing.assert_array_almost_equal(result.column(26), np.cosh(values))
    np.testing.assert_array_almost_equal(result.column(27), np.degrees(values))
    # gcd/lcm were built from literals only, so the reference is a single scalar
    # compared against the whole result column.
    np.testing.assert_array_almost_equal(result.column(28), np.gcd(9, 3))
    np.testing.assert_array_almost_equal(result.column(29), np.lcm(6, 4))
    # nanvl: rows where na_value is NaN are expected to take the fallback 5.
    np.testing.assert_array_almost_equal(
        result.column(30), np.where(np.isnan(na_values), 5, na_values)
    )
    np.testing.assert_array_almost_equal(result.column(31), np.pi)
    np.testing.assert_array_almost_equal(result.column(32), np.radians(values))
    np.testing.assert_array_almost_equal(result.column(33), np.sinh(values))
    np.testing.assert_array_almost_equal(result.column(34), np.tanh(values))
    np.testing.assert_array_almost_equal(result.column(35), math.factorial(6))
    np.testing.assert_array_almost_equal(result.column(36), np.isnan(na_values))
    np.testing.assert_array_almost_equal(result.column(37), na_values == 0)
    # log(base, x) with base 3, checked against NumPy's arbitrary-base logn.
    np.testing.assert_array_almost_equal(
        result.column(38), np.emath.logn(3, values + 1.0)
    )


def py_indexof(arr, v):
    """Return the 1-based position of the first ``v`` in ``arr``, or NaN.

    NaN is returned when ``arr.index`` reports that ``v`` is absent.
    """
    try:
        zero_based = arr.index(v)
    except ValueError:
        return np.nan
    return zero_based + 1


def py_arr_remove(arr, v, n=None):
    """Return a copy of ``arr`` with leading occurrences of ``v`` removed.

    At most ``n`` occurrences are dropped, scanning from the left; with the
    default ``n=None`` every occurrence is dropped. ``arr`` is not modified.
    """
    remaining = arr[:]
    removed = 0
    # Stop once the requested count is reached or no occurrence is left.
    while removed != n and v in remaining:
        remaining.remove(v)
        removed += 1
    return remaining


def py_arr_replace(arr, from_, to, n=None):
    """Return a copy of ``arr`` with occurrences of ``from_`` replaced by ``to``.

    Args:
        arr: List to copy; it is not modified.
        from_: Value to search for.
        to: Replacement value.
        n: Maximum number of replacements, counted from the left. ``None``
            (the default) replaces every occurrence.

    Returns:
        A new list with the replacements applied.
    """
    new_arr = arr[:]
    found = 0
    search_from = 0
    try:
        while found != n:
            # Resume after the previous hit so that a replacement equal to
            # ``from_`` is never matched again; restarting from index 0 would
            # loop forever in that case when ``n`` is None.
            idx = new_arr.index(from_, search_from)
            new_arr[idx] = to
            search_from = idx + 1
            found += 1
    except ValueError:
        # No further occurrence of ``from_``: everything reachable is replaced.
        pass

    return new_arr


def py_arr_resize(arr, size, value):
    """Return ``arr`` as a NumPy array resized to exactly ``size`` elements.

    Args:
        arr: Sequence of values; converted with ``np.asarray``.
        size: Target length.
        value: Fill value appended when ``arr`` is shorter than ``size``.

    Returns:
        A new array: ``arr`` padded on the right with ``value`` when it is too
        short, or its first ``size`` elements when it is too long.
    """
    arr = np.asarray(arr)
    missing = size - arr.shape[0]
    if missing < 0:
        # np.pad rejects negative pad widths, so shrinking is done by slicing.
        return arr[: max(size, 0)].copy()
    return np.pad(
        arr,
        [(0, missing)],
        "constant",
        constant_values=value,
    )


def py_flatten(arr):
    """Flatten arbitrarily nested lists into one flat list, preserving order.

    Only ``list`` instances are descended into; every other element is kept
    as-is.
    """

    def _leaves(items):
        # Depth-first walk yielding the non-list elements left to right.
        for item in items:
            if isinstance(item, list):
                yield from _leaves(item)
            else:
                yield item

    return list(_leaves(arr))


# Each case pairs a builder that turns the array column into a DataFusion
# expression (``stmt``) with a Python function that computes the expected
# per-row results from the raw input data (``py_expr``).
@pytest.mark.parametrize(
    ("stmt", "py_expr"),
    [
        (
            lambda col: f.array_append(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_push_back(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_append(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_push_back(col, literal(99.0)),
            lambda data: [np.append(arr, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_concat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.array_cat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.list_cat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.list_concat(col, col),
            lambda data: [np.concatenate([arr, arr]) for arr in data],
        ),
        (
            lambda col: f.array_dims(col),
            lambda data: [[len(r)] for r in data],
        ),
        (
            lambda col: f.array_distinct(col),
            lambda data: [list(set(r)) for r in data],
        ),
        (
            lambda col: f.list_distinct(col),
            lambda data: [list(set(r)) for r in data],
        ),
        (
            lambda col: f.list_dims(col),
            lambda data: [[len(r)] for r in data],
        ),
        # Element access: the expression uses index 1 where Python uses r[0].
        (
            lambda col: f.array_element(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.array_empty(col),
            lambda data: [len(r) == 0 for r in data],
        ),
        (
            lambda col: f.empty(col),
            lambda data: [len(r) == 0 for r in data],
        ),
        (
            lambda col: f.array_extract(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.list_element(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.list_extract(col, literal(1)),
            lambda data: [r[0] for r in data],
        ),
        (
            lambda col: f.array_length(col),
            lambda data: [len(r) for r in data],
        ),
        (
            lambda col: f.list_length(col),
            lambda data: [len(r) for r in data],
        ),
        (
            lambda col: f.array_has(col, literal(1.0)),
            lambda data: [1.0 in r for r in data],
        ),
        (
            lambda col: f.array_has_all(
                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
            ),
            lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
        ),
        (
            lambda col: f.array_has_any(
                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
            ),
            lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
        ),
        # Position lookups: py_indexof returns a 1-based index, or NaN when the
        # value is absent.
        (
            lambda col: f.array_position(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.array_indexof(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.list_position(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.list_indexof(col, literal(1.0)),
            lambda data: [py_indexof(r, 1.0) for r in data],
        ),
        (
            lambda col: f.array_positions(col, literal(1.0)),
            lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
        ),
        (
            lambda col: f.list_positions(col, literal(1.0)),
            lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data],
        ),
        (
            lambda col: f.array_ndims(col),
            lambda data: [np.array(r).ndim for r in data],
        ),
        (
            lambda col: f.list_ndims(col),
            lambda data: [np.array(r).ndim for r in data],
        ),
        (
            lambda col: f.array_prepend(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_push_front(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_prepend(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.list_push_front(literal(99.0), col),
            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
        ),
        (
            lambda col: f.array_pop_back(col),
            lambda data: [arr[:-1] for arr in data],
        ),
        (
            lambda col: f.array_pop_front(col),
            lambda data: [arr[1:] for arr in data],
        ),
        # Removal: the plain form is modelled as removing one occurrence, the
        # _n form as removing n, and the _all form as removing every one.
        (
            lambda col: f.array_remove(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
        ),
        (
            lambda col: f.list_remove(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
        ),
        (
            lambda col: f.array_remove_n(col, literal(3.0), literal(2)),
            lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
        ),
        (
            lambda col: f.list_remove_n(col, literal(3.0), literal(2)),
            lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
        ),
        (
            lambda col: f.array_remove_all(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0) for arr in data],
        ),
        (
            lambda col: f.list_remove_all(col, literal(3.0)),
            lambda data: [py_arr_remove(arr, 3.0) for arr in data],
        ),
        (
            lambda col: f.array_repeat(col, literal(2)),
            lambda data: [[arr] * 2 for arr in data],
        ),
        (
            lambda col: f.list_repeat(col, literal(2)),
            lambda data: [[arr] * 2 for arr in data],
        ),
        # Replacement follows the same one / n / all pattern as removal.
        (
            lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
        ),
        (
            lambda col: f.list_replace(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
        ),
        (
            lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
        ),
        (
            lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
        ),
        (
            lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
        ),
        (
            lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
            lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
        ),
        (
            lambda col: f.array_sort(col, descending=True, null_first=True),
            lambda data: [np.sort(arr)[::-1] for arr in data],
        ),
        (
            lambda col: f.list_sort(col, descending=False, null_first=False),
            lambda data: [np.sort(arr) for arr in data],
        ),
        # Slicing: array_slice(2, 4) is modelled by the Python slice arr[1:4].
        (
            lambda col: f.array_slice(col, literal(2), literal(4)),
            lambda data: [arr[1:4] for arr in data],
        ),
        pytest.param(
            lambda col: f.list_slice(col, literal(-1), literal(2)),
            lambda data: [arr[-1:2] for arr in data],
        ),
        (
            lambda col: col[:3],
            lambda data: [arr[:3] for arr in data],
        ),
        (
            lambda col: col[1:3],
            lambda data: [arr[1:3] for arr in data],
        ),
        (
            lambda col: col[1:4:2],
            lambda data: [arr[1:4:2] for arr in data],
        ),
        (
            lambda col: col[literal(1) : literal(4)],
            lambda data: [arr[1:4] for arr in data],
        ),
        # Slice bounds come from the "indices" column, so the expected rows are
        # spelled out instead of being derived from ``data``.
        (
            lambda col: col[column("indices") : column("indices") + literal(2)],
            lambda data: [[2.0, 3.0], [], [6.0]],
        ),
        (
            lambda col: col[literal(1) : literal(4) : literal(2)],
            lambda data: [arr[1:4:2] for arr in data],
        ),
        (
            lambda col: f.array_intersect(col, literal([3.0, 4.0])),
            lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
        ),
        (
            lambda col: f.list_intersect(col, literal([3.0, 4.0])),
            lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
        ),
        (
            lambda col: f.array_union(col, literal([12.0, 999.0])),
            lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
        ),
        (
            lambda col: f.list_union(col, literal([12.0, 999.0])),
            lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
        ),
        (
            lambda col: f.array_except(col, literal([3.0])),
            lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
        ),
        (
            lambda col: f.list_except(col, literal([3.0])),
            lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
        ),
        (
            lambda col: f.array_resize(col, literal(10), literal(0.0)),
            lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
        ),
        (
            lambda col: f.list_resize(col, literal(10), literal(0.0)),
            lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
        ),
        # range() ignores the column and a single expected row is supplied, so
        # zip(strict=False) in the test body compares at most one result row.
        (
            lambda col: f.range(literal(1), literal(5), literal(2)),
            lambda data: [np.arange(1, 5, 2)],
        ),
    ],
)
def test_array_functions(stmt, py_expr):
    """Check one array/list expression against its Python reference.

    Args:
        stmt: Callable receiving the ``arr`` column expression and returning
            the DataFusion expression under test.
        py_expr: Callable receiving the raw nested-list data and returning the
            expected value for each row.
    """
    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
    indices = [1, 3, 0]
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays(
        [np.array(data, dtype=object), indices], names=["arr", "indices"]
    )
    df = ctx.create_dataframe([[batch]])

    col = column("arr")
    query_result = df.select(stmt(col)).collect()[0].column(0)
    # Both sides are coerced to float arrays so that scalars, booleans and
    # lists can all go through the same almost-equal comparison.
    for a, b in zip(query_result, py_expr(data), strict=False):
        np.testing.assert_array_almost_equal(
            np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
        )


def test_array_function_flatten():
    """``flatten`` of a nested list literal matches the pure-Python flatten."""
    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
    df = SessionContext().create_dataframe([[batch]])

    flattened = df.select(f.flatten(literal(data))).collect()[0].column(0)
    expected_rows = [py_flatten(data)]
    # Only one expected row is supplied; zip stops at the shorter input, so at
    # most the first result row is compared.
    for actual, expected in zip(flattened, expected_rows, strict=False):
        actual_values = np.array(actual.as_py(), dtype=float)
        expected_values = np.array(expected, dtype=float)
        np.testing.assert_array_almost_equal(actual_values, expected_values)


def test_array_function_cardinality():
    """``cardinality`` returns the number of elements in each list row."""
    data = [[1, 2, 3], [4, 4, 5, 6]]
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
    df = ctx.create_dataframe([[batch]])

    stmt = f.cardinality(column("arr"))
    py_expr = [len(arr) for arr in data]  # Expected lengths: [3, 4]

    query_result = df.select(stmt).collect()[0].column(0)

    # strict=True: a missing or extra result row must fail the test instead of
    # being silently dropped from the comparison.
    for a, b in zip(query_result, py_expr, strict=True):
        np.testing.assert_array_equal(
            np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
        )


@pytest.mark.parametrize("make_func", [f.make_array, f.make_list])
def test_make_array_functions(make_func):
    """``make_array`` / ``make_list`` build one list per row from their inputs."""
    source_columns = {
        "a": pa.array(["Hello", "World", "!"], type=pa.string()),
        "b": pa.array([4, 5, 6]),
        "c": pa.array(["hello ", " world ", " !"], type=pa.string()),
    }
    batch = pa.RecordBatch.from_arrays(
        list(source_columns.values()),
        names=list(source_columns),
    )
    df = SessionContext().create_dataframe([[batch]])

    # Every input is cast to string so the three columns share one element type.
    stmt = make_func(*[column(name).cast(pa.string()) for name in source_columns])
    expected_rows = [
        ["Hello", "4", "hello "],
        ["World", "5", " world "],
        ["!", "6", " !"],
    ]

    actual_rows = df.select(stmt).collect()[0].column(0)
    for actual, expected in zip(actual_rows, expected_rows, strict=False):
        actual_values = np.array(actual.as_py(), dtype=str)
        expected_values = np.array(expected, dtype=str)
        np.testing.assert_array_equal(actual_values, expected_values)


@pytest.mark.parametrize(
    ("stmt", "py_expr"),
    [
        (
            f.array_to_string(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in r) for r in data],
        ),
        (
            f.array_join(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in r) for r in data],
        ),
        (
            f.list_to_string(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in r) for r in data],
        ),
        (
            f.list_join(column("arr"), literal(",")),
            lambda data: [",".join(str(int(v)) for v in r) for r in data],
        ),
    ],
)
def test_array_function_obj_tests(stmt, py_expr):
    """Array-to-string functions join each row's elements with a comma."""
    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
    df = SessionContext().create_dataframe([[batch]])

    result_column = df.select(stmt).collect()[0].column(0)
    actual_rows = np.array(result_column)
    expected_rows = py_expr(data)
    for actual, expected in zip(actual_rows, expected_rows, strict=False):
        assert actual == expected


# Each case pairs a string expression over the ``df`` fixture's columns with the
# exact Arrow array (values and type) the projection is expected to produce.
@pytest.mark.parametrize(
    ("function", "expected_result"),
    [
        (
            f.ascii(column("a")),
            pa.array([72, 87, 33], type=pa.int32()),
        ),  # H = 72; W = 87; ! = 33
        (
            f.bit_length(column("a").cast(pa.string())),
            pa.array([40, 40, 8], type=pa.int32()),
        ),
        (
            f.btrim(literal(" World ")),
            pa.array(["World", "World", "World"], type=pa.string_view()),
        ),
        (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
        (f.chr(literal(68)), pa.array(["D", "D", "D"])),
        (
            f.concat_ws("-", column("a"), literal("test")),
            pa.array(["Hello-test", "World-test", "!-test"]),
        ),
        (
            f.concat(column("a").cast(pa.string()), literal("?")),
            pa.array(["Hello?", "World?", "!?"], type=pa.string_view()),
        ),
        (
            f.initcap(column("c")),
            pa.array(["Hello ", " World ", " !"], type=pa.string_view()),
        ),
        (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
        (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
        (f.lower(column("a")), pa.array(["hello", "world", "!"])),
        (f.lpad(column("a"), literal(7)), pa.array(["  Hello", "  World", "      !"])),
        (
            f.ltrim(column("c")),
            pa.array(["hello ", "world ", "!"], type=pa.string_view()),
        ),
        (
            f.md5(column("a")),
            pa.array(
                [
                    "8b1a9953c4611296a827abf8c47804d7",
                    "f5a7924e621e84c9280a9a27e1bcb7f6",
                    "9033e0e305f247c0c3c80d0c7848c8b3",
                ],
                type=pa.string_view(),
            ),
        ),
        (f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
        (
            f.repeat(column("a"), literal(2)),
            pa.array(["HelloHello", "WorldWorld", "!!"]),
        ),
        (
            f.replace(column("a"), literal("l"), literal("?")),
            pa.array(["He??o", "Wor?d", "!"]),
        ),
        (f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
        (f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
        (
            f.rpad(column("a"), literal(8)),
            pa.array(["Hello   ", "World   ", "!       "]),
        ),
        (
            f.rtrim(column("c")),
            pa.array(["hello", " world", " !"], type=pa.string_view()),
        ),
        (
            f.split_part(column("a"), literal("l"), literal(1)),
            pa.array(["He", "Wor", "!"]),
        ),
        (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
        # strpos is expected to be 1-based, with 0 meaning "not found".
        (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
        (
            f.substr(column("a"), literal(3)),
            pa.array(["llo", "rld", ""], type=pa.string_view()),
        ),
        (
            f.translate(column("a"), literal("or"), literal("ld")),
            pa.array(["Helll", "Wldld", "!"]),
        ),
        (f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
        (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
        (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
        (
            f.overlay(column("a"), literal("--"), literal(2)),
            pa.array(["H--lo", "W--ld", "--"]),
        ),
        (
            f.regexp_like(column("a"), literal("(ell|orl)")),
            pa.array([True, True, False]),
        ),
        # A row with no regex match is expected to be null rather than an
        # empty list.
        (
            f.regexp_match(column("a"), literal("(ell|orl)")),
            pa.array([["ell"], ["orl"], None], type=pa.list_(pa.string_view())),
        ),
        (
            f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
            pa.array(["H-o", "W-d", "!"]),
        ),
        (
            f.regexp_count(column("a"), literal("(ell|orl)"), literal(1)),
            pa.array([1, 1, 0], type=pa.int64()),
        ),
    ],
)
def test_string_functions(df, function, expected_result):
    """Project a single string expression and compare it to the expected array.

    Args:
        df: The module's ``df`` fixture.
        function: DataFusion expression under test.
        expected_result: Arrow array the single output column must equal.
    """
    df = df.select(function)
    result = df.collect()
    assert len(result) == 1
    result = result[0]
    assert result.column(0) == expected_result


def test_hash_functions(df):
    """Check ``digest`` and the dedicated SHA functions on column ``a``.

    Columns 0-6 of the result are ``digest`` with each algorithm name in the
    order listed below; columns 7-10 are ``sha224``/``sha256``/``sha384``/
    ``sha512`` and must equal the corresponding ``digest`` columns.
    """
    exprs = [
        f.digest(column("a"), literal(m))
        for m in (
            "md5",
            "sha224",
            "sha256",
            "sha384",
            "sha512",
            "blake2s",
            "blake3",
        )
    ]
    df = df.select(
        *exprs,
        f.sha224(column("a")),
        f.sha256(column("a")),
        f.sha384(column("a")),
        f.sha512(column("a")),
    )
    result = df.collect()
    assert len(result) == 1
    result = result[0]
    # Shorthand for turning the hex digests below into raw bytes.
    b = bytearray.fromhex
    # Column 0: md5
    assert result.column(0) == pa.array(
        [
            b("8B1A9953C4611296A827ABF8C47804D7"),
            b("F5A7924E621E84C9280A9A27E1BCB7F6"),
            b("9033E0E305F247C0C3C80D0C7848C8B3"),
        ]
    )
    # Column 1: sha224
    assert result.column(1) == pa.array(
        [
            b("4149DA18AA8BFC2B1E382C6C26556D01A92C261B6436DAD5E3BE3FCC"),
            b("12972632B6D3B6AA52BD6434552F08C1303D56B817119406466E9236"),
            b("6641A7E8278BCD49E476E7ACAE158F4105B2952D22AEB2E0B9A231A0"),
        ]
    )
    # Column 2: sha256
    assert result.column(2) == pa.array(
        [
            b("185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969"),
            b("78AE647DC5544D227130A0682A51E30BC7777FBB6D8A8F17007463A3ECD1D524"),
            b("BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62"),
        ]
    )
    # Column 3: sha384
    assert result.column(3) == pa.array(
        [
            b(
                "3519FE5AD2C596EFE3E276A6F351B8FC"
                "0B03DB861782490D45F7598EBD0AB5FD"
                "5520ED102F38C4A5EC834E98668035FC"
            ),
            b(
                "ED7CED84875773603AF90402E42C65F3"
                "B48A5E77F84ADC7A19E8F3E8D3101010"
                "22F552AEC70E9E1087B225930C1D260A"
            ),
            b(
                "1D0EC8C84EE9521E21F06774DE232367"
                "B64DE628474CB5B2E372B699A1F55AE3"
                "35CC37193EF823E33324DFD9A70738A6"
            ),
        ]
    )
    # Column 4: sha512
    assert result.column(4) == pa.array(
        [
            b(
                "3615F80C9D293ED7402687F94B22D58E"
                "529B8CC7916F8FAC7FDDF7FBD5AF4CF7"
                "77D3D795A7A00A16BF7E7F3FB9561EE9"
                "BAAE480DA9FE7A18769E71886B03F315"
            ),
            b(
                "8EA77393A42AB8FA92500FB077A9509C"
                "C32BC95E72712EFA116EDAF2EDFAE34F"
                "BB682EFDD6C5DD13C117E08BD4AAEF71"
                "291D8AACE2F890273081D0677C16DF0F"
            ),
            b(
                "3831A6A6155E509DEE59A7F451EB3532"
                "4D8F8F2DF6E3708894740F98FDEE2388"
                "9F4DE5ADB0C5010DFB555CDA77C8AB5D"
                "C902094C52DE3278F35A75EBC25F093A"
            ),
        ]
    )
    # Column 5: blake2s
    assert result.column(5) == pa.array(
        [
            b("F73A5FBF881F89B814871F46E26AD3FA37CB2921C5E8561618639015B3CCBB71"),
            b("B792A0383FB9E7A189EC150686579532854E44B71AC394831DAED169BA85CCC5"),
            b("27988A0E51812297C77A433F635233346AEE29A829DCF4F46E0F58F402C6CFCB"),
        ]
    )
    # Column 6: blake3
    assert result.column(6) == pa.array(
        [
            b("FBC2B0516EE8744D293B980779178A3508850FDCFE965985782C39601B65794F"),
            b("BF73D18575A736E4037D45F9E316085B86C19BE6363DE6AA789E13DEAACC1C4E"),
            b("C8D11B9F7237E4034ADBCD2005735F9BC4C597C75AD89F4492BEC8F77D15F7EB"),
        ]
    )
    # The dedicated SHA functions must agree with digest() for the same algorithm.
    assert result.column(7) == result.column(1)  # SHA-224
    assert result.column(8) == result.column(2)  # SHA-256
    assert result.column(9) == result.column(3)  # SHA-384
    assert result.column(10) == result.column(4)  # SHA-512


def test_temporal_functions(df):
    """Exercise the date/time helpers against the fixture's timestamp column ``d``.

    The projection covers part extraction, truncation, binning, and the
    ``to_timestamp*`` family, each called both without and with an explicit
    format string. Every output column is compared with a hand-built array.
    """
    ts_col = column("d")
    micros_text = "2023-09-07 05:06:14.523952"
    nanos_text = "2023-09-07 05:06:14.523952000"
    strptime_format = "%Y-%m-%d %H:%M:%S.%f"

    projections = [
        f.date_part(literal("month"), ts_col),
        f.datepart(literal("year"), ts_col),
        f.date_trunc(literal("month"), ts_col),
        f.datetrunc(literal("day"), ts_col),
        # 15-minute bins anchored at an origin of 2001-01-01 00:02:30
        f.date_bin(
            literal("15 minutes").cast(pa.string()),
            ts_col,
            literal("2001-01-01 00:02:30").cast(pa.string()),
        ),
        f.from_unixtime(literal(1673383974)),
        # parsing without a format string
        f.to_timestamp(literal(micros_text)),
        f.to_timestamp_seconds(literal(micros_text)),
        f.to_timestamp_millis(literal(micros_text)),
        f.to_timestamp_micros(literal(micros_text)),
        f.extract(literal("day"), ts_col),
        # parsing with an explicit format string
        f.to_timestamp(literal(nanos_text), literal(strptime_format)),
        f.to_timestamp_seconds(literal(nanos_text), literal(strptime_format)),
        f.to_timestamp_millis(literal(nanos_text), literal(strptime_format)),
        f.to_timestamp_micros(literal(nanos_text), literal(strptime_format)),
        f.to_timestamp_nanos(literal(micros_text)),
        f.to_timestamp_nanos(literal(nanos_text), literal(strptime_format)),
    ]

    batches = df.select(*projections).collect()
    assert len(batches) == 1
    batch = batches[0]

    utc_nanos = pa.timestamp("ns", tz=DEFAULT_TZ)
    # The parsed literal at second, millisecond and microsecond precision,
    # repeated once per input row.
    parsed_seconds = [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3
    parsed_millis = [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3
    parsed_micros = [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3

    # One entry per projected expression, in projection order.
    expected_columns = [
        pa.array([12, 6, 7], type=pa.int32()),
        pa.array([2022, 2027, 2020], type=pa.int32()),
        pa.array(
            [
                datetime(2022, 12, 1, tzinfo=DEFAULT_TZ),
                datetime(2027, 6, 1, tzinfo=DEFAULT_TZ),
                datetime(2020, 7, 1, tzinfo=DEFAULT_TZ),
            ],
            type=utc_nanos,
        ),
        pa.array(
            [
                datetime(2022, 12, 31, tzinfo=DEFAULT_TZ),
                datetime(2027, 6, 26, tzinfo=DEFAULT_TZ),
                datetime(2020, 7, 2, tzinfo=DEFAULT_TZ),
            ],
            type=utc_nanos,
        ),
        pa.array(
            [
                datetime(2022, 12, 30, 23, 47, 30, tzinfo=DEFAULT_TZ),
                datetime(2027, 6, 25, 23, 47, 30, tzinfo=DEFAULT_TZ),
                datetime(2020, 7, 1, 23, 47, 30, tzinfo=DEFAULT_TZ),
            ],
            type=utc_nanos,
        ),
        pa.array(
            [datetime(2023, 1, 10, 20, 52, 54, tzinfo=DEFAULT_TZ)] * 3,
            type=pa.timestamp("s"),
        ),
        pa.array(parsed_micros, type=pa.timestamp("ns")),
        pa.array(parsed_seconds, type=pa.timestamp("s")),
        pa.array(parsed_millis, type=pa.timestamp("ms")),
        pa.array(parsed_micros, type=pa.timestamp("us")),
        pa.array([31, 26, 2], type=pa.int32()),
        pa.array(parsed_micros, type=pa.timestamp("ns")),
        pa.array(parsed_seconds, type=pa.timestamp("s")),
        pa.array(parsed_millis, type=pa.timestamp("ms")),
        pa.array(parsed_micros, type=pa.timestamp("us")),
        pa.array(parsed_micros, type=pa.timestamp("ns")),
        pa.array(parsed_micros, type=pa.timestamp("ns")),
    ]

    for position, expected in enumerate(expected_columns):
        assert batch.column(position) == expected


def test_arrow_cast(df):
    """Cast integer column ``b`` to Float64 and Int32 through ``arrow_cast``."""
    source = column("b")

    # The target type is passed with `string_literal` (utf8) rather than
    # `literal` (utf8view) because datafusion.arrow_cast expects utf8:
    # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
    as_float = f.arrow_cast(source, string_literal("Float64")).alias("b_as_float")
    as_int = f.arrow_cast(source, string_literal("Int32")).alias("b_as_int")

    batches = df.select(as_float, as_int).collect()
    assert len(batches) == 1
    batch = batches[0]

    assert batch.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
    assert batch.column(1) == pa.array([4, 5, 6], type=pa.int32())


def test_case(df):
    """Check ``f.case`` on a base expression, closed by ``otherwise`` or ``end``."""
    # Integer base with a fallback value.
    numeric_with_default = (
        f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8))
    )
    # String base with a fallback value.
    text_with_default = (
        f.case(column("a"))
        .when(literal("Hello"), literal("Hola"))
        .when(literal("World"), literal("Mundo"))
        .otherwise(literal("!!"))
    )
    # String base without a fallback: the unmatched row is expected to be null.
    text_without_default = (
        f.case(column("a"))
        .when(literal("Hello"), literal("Hola"))
        .when(literal("World"), literal("Mundo"))
        .end()
    )

    batch = df.select(
        numeric_with_default, text_with_default, text_without_default
    ).collect()[0]

    assert batch.column(0) == pa.array([10, 8, 8])
    assert batch.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
    assert batch.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())


def test_when_with_no_base(df):
    """Check ``f.when`` chains that start without a base expression.

    The frame is printed with ``show()`` before and after the projection.
    """
    df.show()

    b_col = column("b")
    goldilocks = (
        f.when(b_col > literal(5), literal("too big"))
        .when(b_col < literal(5), literal("too small"))
        .otherwise(literal("just right"))
        .alias("goldilocks")
    )
    # No fallback branch: rows that do not match are expected to be null.
    greeting = (
        f.when(column("a") == literal("Hello"), column("a")).end().alias("greeting")
    )

    df = df.select(b_col, goldilocks, greeting)
    df.show()

    batch = df.collect()[0]
    assert batch.column(0) == pa.array([4, 5, 6])
    assert batch.column(1) == pa.array(
        ["too small", "just right", "too big"], type=pa.string_view()
    )
    assert batch.column(2) == pa.array(["Hello", None, None], type=pa.string_view())


def test_regr_funcs_sql(df):
    """Run every ``regr_*`` aggregate over the constant pair ``(1, 1)`` via SQL.

    The ``df`` fixture is requested but not used; the query runs on a fresh
    ``SessionContext``.
    """
    # Test case based on
    # https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
    cases = [
        ("regr_slope", pa.array([None], type=pa.float64())),
        ("regr_intercept", pa.array([None], type=pa.float64())),
        ("regr_count", pa.array([1], type=pa.uint64())),
        ("regr_r2", pa.array([None], type=pa.float64())),
        ("regr_avgx", pa.array([1], type=pa.float64())),
        ("regr_avgy", pa.array([1], type=pa.float64())),
        ("regr_sxx", pa.array([0], type=pa.float64())),
        ("regr_syy", pa.array([0], type=pa.float64())),
        ("regr_sxy", pa.array([0], type=pa.float64())),
    ]

    # One projected call per aggregate, in the same order as `cases`.
    projection = ", ".join(f"{name}(1,1)" for name, _ in cases)
    batch = SessionContext().sql(f"select {projection};").collect()[0]

    for position, (_, expected) in enumerate(cases):
        assert batch.column(position) == expected


def test_regr_funcs_sql_2():
    """Run every ``regr_*`` aggregate via SQL over three inline (x, y) rows."""
    # Test case based on the `regr_*()` basic tests
    # https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
    expected_by_function = {
        "regr_slope": pa.array([2], type=pa.float64()),
        "regr_intercept": pa.array([0], type=pa.float64()),
        "regr_count": pa.array([3], type=pa.uint64()),
        "regr_r2": pa.array([1], type=pa.float64()),
        "regr_avgx": pa.array([2], type=pa.float64()),
        "regr_avgy": pa.array([4], type=pa.float64()),
        "regr_sxx": pa.array([2], type=pa.float64()),
        "regr_syy": pa.array([8], type=pa.float64()),
        "regr_sxy": pa.array([4], type=pa.float64()),
    }

    # Build the projection in dict order so column positions line up below.
    projection = ", ".join(
        f"{name}(column2, column1)" for name in expected_by_function
    )
    query = f"select {projection} from (values (1,2), (2,4), (3,6))"
    batch = SessionContext().sql(query).collect()[0]

    for position, expected in enumerate(expected_by_function.values()):
        assert batch.column(position) == expected


@pytest.mark.parametrize(
    ("func", "expected"),
    [
        pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"),
        pytest.param(
            f.regr_slope(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [8],
            id="regr_slope_filter",
        ),
        pytest.param(
            f.regr_intercept(column("c2"), column("c1")), [-4], id="regr_intercept"
        ),
        pytest.param(
            f.regr_intercept(
                column("c2"), column("c1"), filter=column("c1") > literal(2)
            ),
            [-16],
            id="regr_intercept_filter",
        ),
        pytest.param(f.regr_count(column("c2"), column("c1")), [4], id="regr_count"),
        pytest.param(
            f.regr_count(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [2],
            id="regr_count_filter",
        ),
        pytest.param(f.regr_r2(column("c2"), column("c1")), [0.92], id="regr_r2"),
        pytest.param(
            f.regr_r2(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [1.0],
            id="regr_r2_filter",
        ),
        pytest.param(f.regr_avgx(column("c2"), column("c1")), [2.5], id="regr_avgx"),
        pytest.param(
            f.regr_avgx(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [3.5],
            id="regr_avgx_filter",
        ),
        pytest.param(f.regr_avgy(column("c2"), column("c1")), [7.5], id="regr_avgy"),
        pytest.param(
            f.regr_avgy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [12.0],
            id="regr_avgy_filter",
        ),
        pytest.param(f.regr_sxx(column("c2"), column("c1")), [5.0], id="regr_sxx"),
        pytest.param(
            f.regr_sxx(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [0.5],
            id="regr_sxx_filter",
        ),
        pytest.param(f.regr_syy(column("c2"), column("c1")), [115.0], id="regr_syy"),
        pytest.param(
            f.regr_syy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [32.0],
            id="regr_syy_filter",
        ),
        pytest.param(f.regr_sxy(column("c2"), column("c1")), [23.0], id="regr_sxy"),
        pytest.param(
            f.regr_sxy(column("c2"), column("c1"), filter=column("c1") > literal(2)),
            [4.0],
            id="regr_sxy_filter",
        ),
    ],
)
def test_regr_funcs_df(func, expected):
    """Evaluate one ``regr_*`` aggregate through the DataFrame API.

    Args:
        func: Aggregate expression over columns ``c1`` and ``c2``; the
            ``*_filter`` cases additionally restrict it to ``c1 > 2``.
        expected: Single-element list holding the expected aggregate value.

    Each input column contains one null. The unfiltered ``regr_count``
    expectation of 4 corresponds to the rows where both values are present.
    """
    # Test case based on the `regr_*()` basic tests
    # https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1

    ctx = SessionContext()

    # Create a DataFrame
    data = {"c1": [1, 2, 3, 4, 5, None], "c2": [2, 4, 8, 16, None, 64]}
    df = ctx.from_pydict(data, name="test_table")

    # Perform the regression function using DataFrame API
    df = df.aggregate([], [func.alias("result")])
    df.show()

    actual = df.collect()[0].to_pydict()

    # The aggregate must come back as the single aliased column.
    assert list(actual) == ["result"]
    # The statistics are computed in floating point (slope 4.6, r2 0.92, ...),
    # so compare with a tolerance instead of exact equality; otherwise a
    # numerically equivalent change in the engine could break the test.
    assert actual["result"] == pytest.approx(expected)


def test_binary_string_functions(df):
    """Base64-encode column ``a`` and check that decoding round-trips it."""
    encoded = f.encode(
        column("a").cast(pa.string()), literal("base64").cast(pa.string())
    )
    round_tripped = f.decode(encoded, literal("base64").cast(pa.string()))

    batches = df.select(encoded, round_tripped).collect()
    assert len(batches) == 1
    batch = batches[0]

    assert batch.column(0) == pa.array(["SGVsbG8", "V29ybGQ", "IQ"])
    # Cast the decoded column to a string array before comparing it with text.
    decoded_text = pa.array(batch.column(1)).cast(pa.string())
    assert decoded_text == pa.array(["Hello", "World", "!"])


@pytest.mark.parametrize(
    ("python_datatype", "name", "expected"),
    [
        pytest.param(bool, "e", pa.bool_(), id="bool"),
        pytest.param(int, "b", pa.int64(), id="int"),
        pytest.param(float, "b", pa.float64(), id="float"),
        pytest.param(str, "b", pa.string(), id="str"),
    ],
)
def test_cast(df, python_datatype, name: str, expected):
    """Casting to a Python builtin type must match casting to ``expected``.

    Column ``name`` is cast once with the Python type and once with the
    pyarrow type, and the two output columns are compared.
    """
    source = column(name)
    projected = df.select(
        source.cast(python_datatype).alias("actual"),
        source.cast(expected).alias("expected"),
    )
    batch = projected.collect()[0]
    assert batch.column(0) == batch.column(1)


@pytest.mark.parametrize(
    ("negated", "low", "high", "expected"),
    [
        pytest.param(False, 3, 5, {"filtered": [4, 5]}),
        pytest.param(False, 4, 5, {"filtered": [4, 5]}),
        pytest.param(True, 3, 5, {"filtered": [6]}),
        pytest.param(True, 4, 6, []),
    ],
)
def test_between(df, negated, low, high, expected):
    """Filter column ``b`` with ``between`` and compare the surviving rows.

    A falsy ``expected`` means the filter should leave no batches at all.
    """
    in_range = column("b").between(low, high, negated=negated)
    batches = df.filter(in_range).select(column("b").alias("filtered")).collect()

    if not expected:
        assert len(batches) == 0  # the rows are empty
        return

    assert batches[0].to_pydict() == expected


def test_between_default(df):
    """Call ``between`` without ``negated`` and expect 4 and 5 from range 3..5."""
    in_range = column("b").between(3, 5)
    batches = df.filter(in_range).select(column("b").alias("filtered")).collect()

    assert batches[0].to_pydict() == {"filtered": [4, 5]}


def test_alias_with_metadata(df):
    """Metadata given to ``f.alias`` must appear on the aliased schema field."""
    aliased = f.alias(f.col("a"), "b", {"key": "value"})
    field = df.select(aliased).schema().field("b")
    # The schema exposes metadata keys and values as bytes.
    assert field.metadata == {b"key": b"value"}


def test_coalesce(df):
    """Check ``f.coalesce`` across string, integer, timestamp and boolean data.

    The ``df`` fixture is requested but not used; the test builds its own
    frame in which the middle row of every column is null.
    """
    first_ts = datetime(2022, 12, 31, tzinfo=DEFAULT_TZ)
    last_ts = datetime(2020, 7, 2, tzinfo=DEFAULT_TZ)
    fallback_ts = datetime(2000, 1, 1, tzinfo=DEFAULT_TZ)

    # Same column names as the shared fixture, with a null in the second row.
    arrays = [
        pa.array(["Hello", None, "!"]),
        pa.array([4, None, 6]),
        pa.array(["hello ", None, " !"]),
        pa.array([first_ts, None, last_ts]),
        pa.array([False, None, True]),
    ]
    batch = pa.RecordBatch.from_arrays(arrays, names=["a", "b", "c", "d", "e"])
    df_with_nulls = SessionContext().create_dataframe([[batch]])

    # One coalesce per column, each with a fallback literal of the matching kind.
    coalesced = df_with_nulls.select(
        f.coalesce(column("a"), literal("default")).alias("a_coalesced"),
        f.coalesce(column("b"), literal(0)).alias("b_coalesced"),
        f.coalesce(column("c"), literal("default")).alias("c_coalesced"),
        f.coalesce(column("d"), literal(fallback_ts)).alias("d_coalesced"),
        f.coalesce(column("e"), literal(value=False)).alias("e_coalesced"),
    ).collect()[0]

    assert coalesced.column(0) == pa.array(
        ["Hello", "default", "!"], type=pa.string_view()
    )
    assert coalesced.column(1) == pa.array([4, 0, 6], type=pa.int64())
    assert coalesced.column(2) == pa.array(
        ["hello ", "default", " !"], type=pa.string_view()
    )
    # The timestamp column is compared as Python values, not as an Arrow array.
    assert coalesced.column(3).to_pylist() == [first_ts, fallback_ts, last_ts]
    assert coalesced.column(4) == pa.array([False, False, True], type=pa.bool_())

    # More than two arguments: a null literal sits between column and fallback.
    multi_arg = f.coalesce(column("a"), literal(None), literal("fallback")).alias(
        "multi_coalesce"
    )
    multi = df_with_nulls.select(multi_arg).collect()[0]
    assert multi.column(0) == pa.array(
        ["Hello", "fallback", "!"], type=pa.string_view()
    )
