| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
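| """Tests for DataFusion expressions, logical plan variants, and expression functions.""" |
|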
| import re |
| from concurrent.futures import ThreadPoolExecutor |
| from datetime import datetime, timezone |
| |
| import pyarrow as pa |
| import pytest |
| from datafusion import ( |
| SessionContext, |
| col, |
| functions, |
| lit, |
| lit_with_metadata, |
| literal_with_metadata, |
| ) |
| from datafusion.expr import ( |
| EXPR_TYPE_ERROR, |
| Aggregate, |
| AggregateFunction, |
| BinaryExpr, |
| Column, |
| CopyTo, |
| CreateIndex, |
| DescribeTable, |
| DmlStatement, |
| DropCatalogSchema, |
| Filter, |
| Limit, |
| Literal, |
| Projection, |
| RecursiveQuery, |
| Sort, |
| TableScan, |
| TransactionEnd, |
| TransactionStart, |
| Values, |
| ensure_expr, |
| ensure_expr_list, |
| ) |
| |
| |
| @pytest.fixture |
| def test_ctx(): |
| ctx = SessionContext() |
| ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv") |
| return ctx |
| |
| |
| def test_projection(test_ctx): |
| df = test_ctx.sql("select c1, 123, c1 < 123 from test") |
| plan = df.logical_plan() |
| |
| plan = plan.to_variant() |
| assert isinstance(plan, Projection) |
| |
| expr = plan.projections() |
| |
| col1 = expr[0].to_variant() |
| assert isinstance(col1, Column) |
| assert col1.name() == "c1" |
| assert col1.qualified_name() == "test.c1" |
| |
| col2 = expr[1].to_variant() |
| assert isinstance(col2, Literal) |
| assert col2.data_type() == "Int64" |
| assert col2.value_i64() == 123 |
| |
| col3 = expr[2].to_variant() |
| assert isinstance(col3, BinaryExpr) |
| assert isinstance(col3.left().to_variant(), Column) |
| assert col3.op() == "<" |
| assert isinstance(col3.right().to_variant(), Literal) |
| |
| plan = plan.input()[0].to_variant() |
| assert isinstance(plan, TableScan) |
| |
| |
| def test_filter(test_ctx): |
| df = test_ctx.sql("select c1 from test WHERE c1 > 5") |
| plan = df.logical_plan() |
| |
| plan = plan.to_variant() |
| assert isinstance(plan, Projection) |
| |
| plan = plan.input()[0].to_variant() |
| assert isinstance(plan, Filter) |
| |
| |
| def test_limit(test_ctx): |
| df = test_ctx.sql("select c1 from test LIMIT 10") |
| plan = df.logical_plan() |
| |
| plan = plan.to_variant() |
| assert isinstance(plan, Limit) |
| assert "Skip: None" in str(plan) |
| |
| df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5") |
| plan = df.logical_plan() |
| |
| plan = plan.to_variant() |
| assert isinstance(plan, Limit) |
| assert "Skip: Some(Literal(Int64(5), None))" in str(plan) |
| |
| |
| def test_aggregate_query(test_ctx): |
| df = test_ctx.sql("select c1, count(*) from test group by c1") |
| plan = df.logical_plan() |
| |
| projection = plan.to_variant() |
| assert isinstance(projection, Projection) |
| |
| aggregate = projection.input()[0].to_variant() |
| assert isinstance(aggregate, Aggregate) |
| |
| col1 = aggregate.group_by_exprs()[0].to_variant() |
| assert isinstance(col1, Column) |
| assert col1.name() == "c1" |
| assert col1.qualified_name() == "test.c1" |
| |
| col2 = aggregate.aggregate_exprs()[0].to_variant() |
| assert isinstance(col2, AggregateFunction) |
| |
| |
| def test_sort(test_ctx): |
| df = test_ctx.sql("select c1 from test order by c1") |
| plan = df.logical_plan() |
| |
| plan = plan.to_variant() |
| assert isinstance(plan, Sort) |
| |
| |
| def test_relational_expr(test_ctx): |
| ctx = SessionContext() |
| |
| batch = pa.RecordBatch.from_arrays( |
| [ |
| pa.array([1, 2, 3]), |
| pa.array(["alpha", "beta", "gamma"], type=pa.string_view()), |
| ], |
| names=["a", "b"], |
| ) |
| df = ctx.create_dataframe([[batch]], name="batch_array") |
| |
| assert df.filter(col("a") == 1).count() == 1 |
| assert df.filter(col("a") != 1).count() == 2 |
| assert df.filter(col("a") >= 1).count() == 3 |
| assert df.filter(col("a") > 1).count() == 2 |
| assert df.filter(col("a") <= 3).count() == 3 |
| assert df.filter(col("a") < 3).count() == 2 |
| |
| assert df.filter(col("b") == "beta").count() == 1 |
| assert df.filter(col("b") != "beta").count() == 2 |
| |
| assert df.filter(col("a") == "beta").count() == 0 |
| |
| |
| def test_expr_to_variant(): |
| # Taken from https://github.com/apache/datafusion-python/issues/781 |
| from datafusion import SessionContext |
| from datafusion.expr import Filter |
| |
| def traverse_logical_plan(plan): |
| cur_node = plan.to_variant() |
| if isinstance(cur_node, Filter): |
| return cur_node.predicate().to_variant() |
| if hasattr(plan, "inputs"): |
| for input_plan in plan.inputs(): |
| res = traverse_logical_plan(input_plan) |
| if res is not None: |
| return res |
| return None |
| |
| ctx = SessionContext() |
| data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]} |
| ctx.from_pydict(data, name="table1") |
| query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')" |
| logical_plan = ctx.sql(query).optimized_logical_plan() |
| variant = traverse_logical_plan(logical_plan) |
| assert variant is not None |
| assert variant.expr().to_variant().qualified_name() == "table1.name" |
| assert ( |
| str(variant.list()) |
| == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' # noqa: E501 |
| ) |
| assert not variant.negated() |
| |
| |
| def test_case_builder_error_preserves_builder_state(): |
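| """A failed otherwise() call must not consume the builder; it stays usable afterwards.""" |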
| case_builder = functions.when(lit(True), lit(1)) |
| |
| with pytest.raises(Exception) as exc_info: |
| _ = case_builder.otherwise(lit("bad")) |
| |
| err_msg = str(exc_info.value) |
| assert "multiple data types" in err_msg |
| assert "CaseBuilder has already been consumed" not in err_msg |
| |
| _ = case_builder.end() |
| |
| err_msg = str(exc_info.value) |
| assert "multiple data types" in err_msg |
| assert "CaseBuilder has already been consumed" not in err_msg |
| |
| |
| def test_case_builder_success_preserves_builder_state(): |
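| """otherwise() and end() can be called repeatedly on the same builder without consuming it.""" |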
| ctx = SessionContext() |
| df = ctx.from_pydict({"flag": [False]}, name="tbl") |
| |
| case_builder = functions.when(col("flag"), lit("true")) |
| |
| expr_default_one = case_builder.otherwise(lit("default-1")).alias("result") |
| result_one = df.select(expr_default_one).collect() |
| assert result_one[0].column(0).to_pylist() == ["default-1"] |
| |
| expr_default_two = case_builder.otherwise(lit("default-2")).alias("result") |
| result_two = df.select(expr_default_two).collect() |
| assert result_two[0].column(0).to_pylist() == ["default-2"] |
| |
| expr_end_one = case_builder.end().alias("result") |
| end_one = df.select(expr_end_one).collect() |
| assert end_one[0].column(0).to_pylist() == [None] |
| |
| |
| def test_case_builder_when_handles_are_independent(): |
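| """Each when() call returns an independent builder handle; the shared base is not mutated.""" |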
| ctx = SessionContext() |
| df = ctx.from_pydict( |
| { |
| "flag": [True, False, False, False], |
| "value": [1, 15, 25, 5], |
| }, |
| name="tbl", |
| ) |
| |
| base_builder = functions.when(col("flag"), lit("flag-true")) |
| |
| first_builder = base_builder.when(col("value") > lit(10), lit("gt10")) |
| second_builder = base_builder.when(col("value") > lit(20), lit("gt20")) |
| |
| first_builder = first_builder.when(lit(True), lit("final-one")) |
| |
| expr_first = first_builder.otherwise(lit("fallback-one")).alias("first") |
| expr_second = second_builder.otherwise(lit("fallback-two")).alias("second") |
| |
| result = df.select(expr_first, expr_second).collect()[0] |
| |
| assert result.column(0).to_pylist() == [ |
| "flag-true", |
| "gt10", |
| "gt10", |
| "final-one", |
| ] |
| assert result.column(1).to_pylist() == [ |
| "flag-true", |
| "fallback-two", |
| "gt20", |
| "fallback-two", |
| ] |
| |
| |
| def test_case_builder_when_thread_safe(): |
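| """Concurrent when() calls on a shared builder must not error or corrupt its state.""" |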
| case_builder = functions.when(lit(True), lit(1)) |
| |
| def build_expr(value: int) -> bool: |
| builder = case_builder.when(lit(True), lit(value)) |
| builder.otherwise(lit(value)) |
| return True |
| |
| with ThreadPoolExecutor(max_workers=8) as executor: |
| futures = [executor.submit(build_expr, idx) for idx in range(16)] |
| results = [future.result() for future in futures] |
| |
| assert all(results) |
| |
| # Ensure the shared builder remains usable after concurrent `when` calls. |
| follow_up_builder = case_builder.when(lit(True), lit(42)) |
| assert isinstance(follow_up_builder, type(case_builder)) |
| follow_up_builder.otherwise(lit(7)) |
| |
| |
| def test_expr_getitem() -> None: |
| ctx = SessionContext() |
| data = { |
| "array_values": [[1, 2, 3], [4, 5], [6], []], |
| "struct_values": [ |
| {"name": "Alice", "age": 15}, |
| {"name": "Bob", "age": 14}, |
| {"name": "Charlie", "age": 13}, |
| {"name": None, "age": 12}, |
| ], |
| } |
| df = ctx.from_pydict(data, name="table1") |
| |
| names = df.select(col("struct_values")["name"].alias("name")).collect() |
| names = [r.as_py() for rs in names for r in rs["name"]] |
| |
| array_values = df.select(col("array_values")[1].alias("value")).collect() |
| array_values = [r.as_py() for rs in array_values for r in rs["value"]] |
| |
| assert names == ["Alice", "Bob", "Charlie", None] |
| assert array_values == [2, 5, None, None] |
| |
| |
| def test_display_name_deprecation(): |
| import warnings |
| |
| expr = col("foo") |
| with warnings.catch_warnings(record=True) as w: |
| # Cause all warnings to always be triggered |
| warnings.simplefilter("always") |
| |
| # Calling the deprecated method should trigger the warning |
| name = expr.display_name() |
| |
| # Verify exactly one DeprecationWarning was raised |
| assert len(w) == 1 |
| assert issubclass(w[-1].category, DeprecationWarning) |
| assert "deprecated" in str(w[-1].message) |
| |
| # The deprecated method still returns the same value as schema_name() |
| assert name == expr.schema_name() |
| assert name == "foo" |
| |
| |
| @pytest.fixture |
| def df(): |
| ctx = SessionContext() |
| |
| # create a RecordBatch and a new DataFrame from it |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])], |
| names=["a", "b", "c"], |
| ) |
| |
| return ctx.from_arrow(batch) |
| |
| |
| def test_fill_null(df): |
| df = df.select( |
| col("a").fill_null(100).alias("a"), |
| col("b").fill_null(25).alias("b"), |
| col("c").fill_null(1234).alias("c"), |
| ) |
| df.show() |
| result = df.collect()[0] |
| |
| assert result.column(0) == pa.array([1, 2, 100]) |
| assert result.column(1) == pa.array([4, 25, 6]) |
| assert result.column(2) == pa.array([1234, 1234, 8]) |
| |
| |
| def test_copy_to(): |
| ctx = SessionContext() |
| ctx.sql("CREATE TABLE foo (a int, b int)").collect() |
| df = ctx.sql("COPY foo TO bar STORED AS CSV") |
| plan = df.logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, CopyTo) |
| |
| |
| def test_create_index(): |
| ctx = SessionContext() |
| ctx.sql("CREATE TABLE foo (a int, b int)").collect() |
| plan = ctx.sql("create index idx on foo (a)").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, CreateIndex) |
| |
| |
| def test_describe_table(): |
| ctx = SessionContext() |
| ctx.sql("CREATE TABLE foo (a int, b int)").collect() |
| plan = ctx.sql("describe foo").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, DescribeTable) |
| |
| |
| def test_dml_statement(): |
| ctx = SessionContext() |
| ctx.sql("CREATE TABLE foo (a int, b int)").collect() |
| plan = ctx.sql("insert into foo values (1, 2)").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, DmlStatement) |
| |
| |
| def test_drop_catalog_schema(): |
| ctx = SessionContext() |
| plan = ctx.sql("drop schema cat").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, DropCatalogSchema) |
| |
| |
| def test_recursive_query(): |
| ctx = SessionContext() |
| plan = ctx.sql( |
| """ |
| WITH RECURSIVE cte AS ( |
| SELECT 1 as n |
| UNION ALL |
| SELECT n + 1 FROM cte WHERE n < 5 |
| ) |
| SELECT * FROM cte; |
| """ |
| ).logical_plan() |
| plan = plan.inputs()[0].inputs()[0].to_variant() |
| assert isinstance(plan, RecursiveQuery) |
| |
| |
| def test_values(): |
| ctx = SessionContext() |
| plan = ctx.sql("values (1, 'foo'), (2, 'bar')").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, Values) |
| |
| |
| def test_transaction_start(): |
| ctx = SessionContext() |
| plan = ctx.sql("START TRANSACTION").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, TransactionStart) |
| |
| |
| def test_transaction_end(): |
| ctx = SessionContext() |
| plan = ctx.sql("COMMIT").logical_plan() |
| plan = plan.to_variant() |
| assert isinstance(plan, TransactionEnd) |
| |
| |
| def test_col_getattr(): |
| ctx = SessionContext() |
| data = { |
| "array_values": [[1, 2, 3], [4, 5], [6], []], |
| "struct_values": [ |
| {"name": "Alice", "age": 15}, |
| {"name": "Bob", "age": 14}, |
| {"name": "Charlie", "age": 13}, |
| {"name": None, "age": 12}, |
| ], |
| } |
| df = ctx.from_pydict(data, name="table1") |
| |
| names = df.select(col.struct_values["name"].alias("name")).collect() |
| names = [r.as_py() for rs in names for r in rs["name"]] |
| |
| array_values = df.select(col.array_values[1].alias("value")).collect() |
| array_values = [r.as_py() for rs in array_values for r in rs["value"]] |
| |
| assert names == ["Alice", "Bob", "Charlie", None] |
| assert array_values == [2, 5, None, None] |
| |
| |
| def test_alias_with_metadata(df): |
| df = df.select(col("a").alias("b", {"key": "value"})) |
| assert df.schema().field("b").metadata == {b"key": b"value"} |
| |
| |
| # These unit tests ensure the expression functions do not regress. |
| # For the math functions we wrap results in `functions.round` so we can |
| # test for equivalence without worrying about floating point precision. |
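| # The expressions below run against the batch built in test_expr_functions, |
| # which provides the following columns: |
| #   a: float64 [-0.75, 0.5, 0.0, None] |
| #   b: int64 [-30, 42, 0, None] |
| #   c: string_view ["Hello", " world ", "!", None] |
| #   d: list<int64> [[-1, 1, 0], [5, 10, 15, 20], [], None] |
| #   e: float64 [-0.75, NaN, 0.0, None] |
| #   f: list<list<int64>> [[[-1, 1, 0], [4, 4]], [[5, 10, 15, 20], [3]], [[]], [None]] |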
| @pytest.mark.parametrize( |
| ("function", "expected_result"), |
| [ |
| # |
| # Math Functions |
| # |
| pytest.param( |
| functions.round(col("a").asin(), lit(4)), |
| pa.array([-0.8481, 0.5236, 0.0, None], type=pa.float64()), |
| id="asin", |
| ), |
| pytest.param( |
| functions.round(col("a").sin(), lit(4)), |
| pa.array([-0.6816, 0.4794, 0.0, None], type=pa.float64()), |
| id="sin", |
| ), |
| pytest.param( |
| # log10 of a negative number returns NaN, which cannot be compared for |
| # equivalence, so take abs() first. |
| functions.round(col("a").abs().log10(), lit(4)), |
| pa.array([-0.1249, -0.301, -float("inf"), None], type=pa.float64()), |
| id="log10", |
| ), |
| pytest.param( |
| col("a").iszero(), |
| pa.array([False, False, True, None], type=pa.bool_()), |
| id="iszero", |
| ), |
| pytest.param( |
| functions.round(col("a").acos(), lit(4)), |
| pa.array([2.4189, 1.0472, 1.5708, None], type=pa.float64()), |
| id="acos", |
| ), |
| pytest.param( |
| col("e").isnan(), |
| pa.array([False, True, False, None], type=pa.bool_()), |
| id="isnan", |
| ), |
| pytest.param( |
| functions.round(col("a").degrees(), lit(4)), |
| pa.array([-42.9718, 28.6479, 0.0, None], type=pa.float64()), |
| id="degrees", |
| ), |
| pytest.param( |
| functions.round(col("a").asinh(), lit(4)), |
| pa.array([-0.6931, 0.4812, 0.0, None], type=pa.float64()), |
| id="asinh", |
| ), |
| pytest.param( |
| col("a").abs(), |
| pa.array([0.75, 0.5, 0.0, None], type=pa.float64()), |
| id="abs", |
| ), |
| pytest.param( |
| functions.round(col("a").exp(), lit(4)), |
| pa.array([0.4724, 1.6487, 1.0, None], type=pa.float64()), |
| id="exp", |
| ), |
| pytest.param( |
| functions.round(col("a").cosh(), lit(4)), |
| pa.array([1.2947, 1.1276, 1.0, None], type=pa.float64()), |
| id="cosh", |
| ), |
| pytest.param( |
| functions.round(col("a").radians(), lit(4)), |
| pa.array([-0.0131, 0.0087, 0.0, None], type=pa.float64()), |
| id="radians", |
| ), |
| pytest.param( |
| functions.round(col("a").abs().sqrt(), lit(4)), |
| pa.array([0.866, 0.7071, 0.0, None], type=pa.float64()), |
| id="sqrt", |
| ), |
| pytest.param( |
| functions.round(col("a").tanh(), lit(4)), |
| pa.array([-0.6351, 0.4621, 0.0, None], type=pa.float64()), |
| id="tanh", |
| ), |
| pytest.param( |
| functions.round(col("a").atan(), lit(4)), |
| pa.array([-0.6435, 0.4636, 0.0, None], type=pa.float64()), |
| id="atan", |
| ), |
| pytest.param( |
| functions.round(col("a").atanh(), lit(4)), |
| pa.array([-0.973, 0.5493, 0.0, None], type=pa.float64()), |
| id="atanh", |
| ), |
| pytest.param( |
| # Large inputs overflow int64, so divide first to keep the values small. |
| (col("b") / lit(4)).factorial(), |
| pa.array([1, 3628800, 1, None], type=pa.int64()), |
| id="factorial", |
| ), |
| pytest.param( |
| # acosh is only defined for inputs >= 1.0 |
| functions.round((col("a").abs() + lit(1.0)).acosh(), lit(4)), |
| pa.array([1.1588, 0.9624, 0.0, None], type=pa.float64()), |
| id="acosh", |
| ), |
| pytest.param( |
| col("a").floor(), |
| pa.array([-1.0, 0.0, 0.0, None], type=pa.float64()), |
| id="floor", |
| ), |
| pytest.param( |
| col("a").ceil(), |
| pa.array([-0.0, 1.0, 0.0, None], type=pa.float64()), |
| id="ceil", |
| ), |
| pytest.param( |
| functions.round(col("a").abs().ln(), lit(4)), |
| pa.array([-0.2877, -0.6931, float("-inf"), None], type=pa.float64()), |
| id="ln", |
| ), |
| pytest.param( |
| functions.round(col("a").tan(), lit(4)), |
| pa.array([-0.9316, 0.5463, 0.0, None], type=pa.float64()), |
| id="tan", |
| ), |
| pytest.param( |
| functions.round(col("a").cbrt(), lit(4)), |
| pa.array([-0.9086, 0.7937, 0.0, None], type=pa.float64()), |
| id="cbrt", |
| ), |
| pytest.param( |
| functions.round(col("a").cos(), lit(4)), |
| pa.array([0.7317, 0.8776, 1.0, None], type=pa.float64()), |
| id="cos", |
| ), |
| pytest.param( |
| functions.round(col("a").sinh(), lit(4)), |
| pa.array([-0.8223, 0.5211, 0.0, None], type=pa.float64()), |
| id="sinh", |
| ), |
| pytest.param( |
| col("a").signum(), |
| pa.array([-1.0, 1.0, 0.0, None], type=pa.float64()), |
| id="signum", |
| ), |
| pytest.param( |
| functions.round(col("a").abs().log2(), lit(4)), |
| pa.array([-0.415, -1.0, float("-inf"), None], type=pa.float64()), |
| id="log2", |
| ), |
| pytest.param( |
| functions.round(col("a").cot(), lit(4)), |
| pa.array([-1.0734, 1.8305, float("inf"), None], type=pa.float64()), |
| id="cot", |
| ), |
| # |
| # String Functions |
| # |
| pytest.param( |
| col("c").reverse(), |
| pa.array(["olleH", " dlrow ", "!", None], type=pa.string()), |
| id="reverse", |
| ), |
| pytest.param( |
| col("c").bit_length(), |
| pa.array([40, 56, 8, None], type=pa.int32()), |
| id="bit_length", |
| ), |
| pytest.param( |
| col("b").to_hex(), |
| pa.array(["ffffffffffffffe2", "2a", "0", None], type=pa.string()), |
| id="to_hex", |
| ), |
| pytest.param( |
| col("c").length(), |
| pa.array([5, 7, 1, None], type=pa.int32()), |
| id="length", |
| ), |
| pytest.param( |
| col("c").lower(), |
| pa.array(["hello", " world ", "!", None], type=pa.string()), |
| id="lower", |
| ), |
| pytest.param( |
| col("c").ascii(), |
| pa.array([72, 32, 33, None], type=pa.int32()), |
| id="ascii", |
| ), |
| pytest.param( |
| col("c").sha512(), |
| pa.array( |
| [ |
| bytes.fromhex( |
| "3615F80C9D293ED7402687F94B22D58E529B8CC7916F8FAC7FDDF7FBD5AF4CF777D3D795A7A00A16BF7E7F3FB9561EE9BAAE480DA9FE7A18769E71886B03F315" |
| ), |
| bytes.fromhex( |
| "A6758FDA3C2F0B554084E18308EA99B94B54EEE8FDA72697CEA7844E524CC2F2F2EE4CC8BAC87D2E3E7222959FE3D0CA1A841761FDC0D1780F6FE9E39E369500" |
| ), |
| bytes.fromhex( |
| "3831A6A6155E509DEE59A7F451EB35324D8F8F2DF6E3708894740F98FDEE23889F4DE5ADB0C5010DFB555CDA77C8AB5DC902094C52DE3278F35A75EBC25F093A" |
| ), |
| None, |
| ], |
| type=pa.binary(), |
| ), |
| id="sha512", |
| ), |
| pytest.param( |
| col("c").sha384(), |
| pa.array( |
| [ |
| bytes.fromhex( |
| "3519FE5AD2C596EFE3E276A6F351B8FC0B03DB861782490D45F7598EBD0AB5FD5520ED102F38C4A5EC834E98668035FC" |
| ), |
| bytes.fromhex( |
| "A6A38A9AE2CFD0D67F49989AD584632BF7D7A07DAD2277E92326A6A0B37F884A871D6173FB342CFE258E375258ACAAEC" |
| ), |
| bytes.fromhex( |
| "1D0EC8C84EE9521E21F06774DE232367B64DE628474CB5B2E372B699A1F55AE335CC37193EF823E33324DFD9A70738A6" |
| ), |
| None, |
| ], |
| type=pa.binary(), |
| ), |
| id="sha384", |
| ), |
| pytest.param( |
| col("c").sha256(), |
| pa.array( |
| [ |
| bytes.fromhex( |
| "185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969" |
| ), |
| bytes.fromhex( |
| "DE2EF0D77D456EC1CDE2C52F75996F6636A64079297213D548D875A488B03A75" |
| ), |
| bytes.fromhex( |
| "BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62" |
| ), |
| None, |
| ], |
| type=pa.binary(), |
| ), |
| id="sha256", |
| ), |
| pytest.param( |
| col("c").sha224(), |
| pa.array( |
| [ |
| bytes.fromhex( |
| "4149DA18AA8BFC2B1E382C6C26556D01A92C261B6436DAD5E3BE3FCC" |
| ), |
| bytes.fromhex( |
| "AD6DF6D9ECDDF50AF2A72D5E3144BA813EE954537572C0E8AB3066BE" |
| ), |
| bytes.fromhex( |
| "6641A7E8278BCD49E476E7ACAE158F4105B2952D22AEB2E0B9A231A0" |
| ), |
| None, |
| ], |
| type=pa.binary(), |
| ), |
| id="sha224", |
| ), |
| pytest.param( |
| col("c").btrim(), |
| pa.array(["Hello", "world", "!", None], type=pa.string_view()), |
| id="btrim", |
| ), |
| pytest.param( |
| col("c").trim(), |
| pa.array(["Hello", "world", "!", None], type=pa.string_view()), |
| id="trim", |
| ), |
| pytest.param( |
| col("c").md5(), |
| pa.array( |
| [ |
| "8b1a9953c4611296a827abf8c47804d7", |
| "de802497c24568d9a85d4eb8c2b6e8fe", |
| "9033e0e305f247c0c3c80d0c7848c8b3", |
| None, |
| ], |
| type=pa.string_view(), |
| ), |
| id="md5", |
| ), |
| pytest.param( |
| col("c").octet_length(), |
| pa.array([5, 7, 1, None], type=pa.int32()), |
| id="octet_length", |
| ), |
| pytest.param( |
| col("c").character_length(), |
| pa.array([5, 7, 1, None], type=pa.int32()), |
| id="character_length", |
| ), |
| pytest.param( |
| col("c").char_length(), |
| pa.array([5, 7, 1, None], type=pa.int32()), |
| id="char_length", |
| ), |
| pytest.param( |
| col("c").rtrim(), |
| pa.array(["Hello", " world", "!", None], type=pa.string_view()), |
| id="rtrim", |
| ), |
| pytest.param( |
| col("c").ltrim(), |
| pa.array(["Hello", "world ", "!", None], type=pa.string_view()), |
| id="ltrim", |
| ), |
| pytest.param( |
| col("c").upper(), |
| pa.array(["HELLO", " WORLD ", "!", None], type=pa.string()), |
| id="upper", |
| ), |
| pytest.param( |
| lit(65).chr(), |
| pa.array(["A", "A", "A", "A"], type=pa.string()), |
| id="chr", |
| ), |
| # |
| # Time Functions |
| # |
| pytest.param( |
| col("b").from_unixtime(), |
| pa.array( |
| [ |
| datetime(1969, 12, 31, 23, 59, 30, tzinfo=timezone.utc), |
| datetime(1970, 1, 1, 0, 0, 42, tzinfo=timezone.utc), |
| datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc), |
| None, |
| ], |
| type=pa.timestamp("s"), |
| ), |
| id="from_unixtime", |
| ), |
| pytest.param( |
| col("c").initcap(), |
| pa.array(["Hello", " World ", "!", None], type=pa.string_view()), |
| id="initcap", |
| ), |
| # |
| # Array Functions |
| # |
| pytest.param( |
| col("d").array_pop_back(), |
| pa.array([[-1, 1], [5, 10, 15], [], None], type=pa.list_(pa.int64())), |
| id="array_pop_back", |
| ), |
| pytest.param( |
| col("d").array_pop_front(), |
| pa.array([[1, 0], [10, 15, 20], [], None], type=pa.list_(pa.int64())), |
| id="array_pop_front", |
| ), |
| pytest.param( |
| col("d").array_length(), |
| pa.array([3, 4, 0, None], type=pa.uint64()), |
| id="array_length", |
| ), |
| pytest.param( |
| col("d").list_length(), |
| pa.array([3, 4, 0, None], type=pa.uint64()), |
| id="list_length", |
| ), |
| pytest.param( |
| col("d").array_ndims(), |
| pa.array([1, 1, 1, None], type=pa.uint64()), |
| id="array_ndims", |
| ), |
| pytest.param( |
| col("d").list_ndims(), |
| pa.array([1, 1, 1, None], type=pa.uint64()), |
| id="list_ndims", |
| ), |
| pytest.param( |
| col("d").array_dims(), |
| pa.array([[3], [4], None, None], type=pa.list_(pa.uint64())), |
| id="array_dims", |
| ), |
| pytest.param( |
| col("d").array_empty(), |
| pa.array([False, False, True, None], type=pa.bool_()), |
| id="array_empty", |
| ), |
| pytest.param( |
| col("d").list_distinct(), |
| pa.array( |
| [[-1, 0, 1], [5, 10, 15, 20], [], None], type=pa.list_(pa.int64()) |
| ), |
| id="list_distinct", |
| ), |
| pytest.param( |
| col("d").array_distinct(), |
| pa.array( |
| [[-1, 0, 1], [5, 10, 15, 20], [], None], type=pa.list_(pa.int64()) |
| ), |
| id="array_distinct", |
| ), |
| pytest.param( |
| col("d").cardinality(), |
| pa.array([3, 4, None, None], type=pa.uint64()), |
| id="cardinality", |
| ), |
| pytest.param( |
| col("f").flatten(), |
| pa.array( |
| [[-1, 1, 0, 4, 4], [5, 10, 15, 20, 3], [], []], |
| type=pa.list_(pa.int64()), |
| ), |
| id="flatten", |
| ), |
| pytest.param( |
| col("d").list_dims(), |
| pa.array([[3], [4], None, None], type=pa.list_(pa.uint64())), |
| id="list_dims", |
| ), |
| pytest.param( |
| col("d").empty(), |
| pa.array([False, False, True, None], type=pa.bool_()), |
| id="empty", |
| ), |
| # |
| # Other Tests |
| # |
| pytest.param( |
| col("d").arrow_typeof(), |
| pa.array( |
| [ |
| 'List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })', # noqa: E501 |
| 'List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })', # noqa: E501 |
| 'List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })', # noqa: E501 |
| 'List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })', # noqa: E501 |
| ], |
| type=pa.string(), |
| ), |
| id="arrow_typeof", |
| ), |
| ], |
| ) |
| def test_expr_functions(ctx, function, expected_result): |
| batch = pa.RecordBatch.from_arrays( |
| [ |
| pa.array([-0.75, 0.5, 0.0, None], type=pa.float64()), |
| pa.array([-30, 42, 0, None], type=pa.int64()), |
| pa.array(["Hello", " world ", "!", None], type=pa.string_view()), |
| pa.array( |
| [[-1, 1, 0], [5, 10, 15, 20], [], None], type=pa.list_(pa.int64()) |
| ), |
| pa.array([-0.75, float("nan"), 0.0, None], type=pa.float64()), |
| pa.array( |
| [[[-1, 1, 0], [4, 4]], [[5, 10, 15, 20], [3]], [[]], [None]], |
| type=pa.list_(pa.list_(pa.int64())), |
| ), |
| ], |
| names=["a", "b", "c", "d", "e", "f"], |
| ) |
| df = ctx.create_dataframe([[batch]]).select(function) |
| result = df.collect() |
| |
| assert len(result) == 1 |
| assert result[0].column(0).equals(expected_result) |
| |
| |
| def test_literal_metadata(ctx): |
| result = ( |
| ctx.from_pydict({"a": [1]}) |
| .select( |
| lit(1).alias("no_metadata"), |
| lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"), |
| literal_with_metadata(3, {"key2": "value2"}).alias( |
| "literal_with_metadata_fn" |
| ), |
| ) |
| .collect() |
| ) |
| |
| expected_schema = pa.schema( |
| [ |
| pa.field("no_metadata", pa.int64(), nullable=False), |
| pa.field( |
| "lit_with_metadata_fn", |
| pa.int64(), |
| nullable=False, |
| metadata={"key1": "value1"}, |
| ), |
| pa.field( |
| "literal_with_metadata_fn", |
| pa.int64(), |
| nullable=False, |
| metadata={"key2": "value2"}, |
| ), |
| ] |
| ) |
| |
| expected = pa.RecordBatch.from_pydict( |
| { |
| "no_metadata": pa.array([1]), |
| "lit_with_metadata_fn": pa.array([2]), |
| "literal_with_metadata_fn": pa.array([3]), |
| }, |
| schema=expected_schema, |
| ) |
| |
| assert result[0] == expected |
| |
| # Comparing result[0].schema == expected_schema does not check each metadata |
| # key/value pair, so test them explicitly. |
| for expected_field in expected_schema: |
| actual_field = result[0].schema.field(expected_field.name) |
| assert expected_field.metadata == actual_field.metadata |
| |
| |
| def test_ensure_expr(): |
| e = col("a") |
| assert ensure_expr(e) is e.expr |
| with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): |
| ensure_expr("a") |
| |
| |
| def test_ensure_expr_list_string(): |
| with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): |
| ensure_expr_list("a") |
| |
| |
| def test_ensure_expr_list_bytes(): |
| with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): |
| ensure_expr_list(b"a") |
| |
| |
| def test_ensure_expr_list_bytearray(): |
| with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)): |
| ensure_expr_list(bytearray(b"a")) |