| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import datetime |
| import pytest |
| |
| import pyarrow as pa |
| |
| |
@pytest.mark.gandiva
def test_tree_exp_builder():
    """Project an if/else expression over two int32 columns.

    The expression is "a if a > b else b"; the single projected array is
    compared with a hand-written expected array.
    """
    import pyarrow.gandiva as gandiva

    # Input columns, the expected projection and the batch to evaluate.
    col_a = pa.array([10, 12, -20, 5], type=pa.int32())
    col_b = pa.array([5, 15, 15, 17], type=pa.int32())
    expected = pa.array([10, 15, 15, 17], type=pa.int32())
    batch = pa.RecordBatch.from_arrays([col_a, col_b], names=['a', 'b'])

    # Schema the projector is built against.
    a_field = pa.field('a', pa.int32())
    b_field = pa.field('b', pa.int32())
    input_schema = pa.schema([a_field, b_field])

    tree = gandiva.TreeExprBuilder()
    a_node = tree.make_field(a_field)
    b_node = tree.make_field(b_field)
    a_gt_b = tree.make_function("greater_than", [a_node, b_node], pa.bool_())
    pick = tree.make_if(a_gt_b, a_node, b_node, pa.int32())
    expression = tree.make_expression(pick, pa.field('res', pa.int32()))

    projector = gandiva.make_projector(
        input_schema, [expression], pa.default_memory_pool())

    # Exactly one expression was registered, so exactly one array comes back.
    (projected,) = projector.evaluate(batch)
    assert projected.equals(expected)
| |
| |
@pytest.mark.gandiva
def test_table():
    """Project ``add(a, b)`` over the first record batch of a Table."""
    import pyarrow.gandiva as gandiva

    columns = [pa.array([1.0, 2.0]), pa.array([3.0, 4.0])]
    tbl = pa.Table.from_arrays(columns, ['a', 'b'])
    schema = tbl.schema

    tree = gandiva.TreeExprBuilder()
    operands = [
        tree.make_field(schema.field_by_name(name)) for name in ("a", "b")
    ]
    total = tree.make_function("add", operands, pa.float64())
    expression = tree.make_expression(total, pa.field("c", pa.float64()))

    projector = gandiva.make_projector(
        schema, [expression], pa.default_memory_pool())

    # TODO: Add .evaluate function which can take Tables instead of
    # RecordBatches
    first_batch = tbl.to_batches()[0]
    (projected,) = projector.evaluate(first_batch)

    assert projected.equals(pa.array([4.0, 6.0]))
| |
| |
@pytest.mark.gandiva
def test_filter():
    """Filter a float column with ``a < 1000.0`` and check the selection.

    The column holds 0.0 through 9999.0; the resulting selection vector is
    compared against the uint32 values 0 through 999.
    """
    import pyarrow.gandiva as gandiva

    values = pa.array([1.0 * i for i in range(10000)])
    tbl = pa.Table.from_arrays([values], ['a'])

    tree = gandiva.TreeExprBuilder()
    column = tree.make_field(tbl.schema.field_by_name("a"))
    limit = tree.make_literal(1000.0, pa.float64())
    below_limit = tree.make_function(
        "less_than", [column, limit], pa.bool_())
    predicate = tree.make_condition(below_limit)

    row_filter = gandiva.make_filter(tbl.schema, predicate)
    selection = row_filter.evaluate(
        tbl.to_batches()[0], pa.default_memory_pool())

    expected = pa.array(range(1000), type=pa.uint32())
    assert selection.to_array().equals(expected)
| |
| |
@pytest.mark.gandiva
def test_in_expr():
    """Check IN-expression filters on string, int32 and int64 columns."""
    import pyarrow.gandiva as gandiva

    # One builder is shared by all three sections.
    tree = gandiva.TreeExprBuilder()

    def selected(tbl, candidates, value_type):
        """Filter column ``a`` with an IN expression; return the selection."""
        column = tree.make_field(tbl.schema.field_by_name("a"))
        membership = tree.make_in_expression(column, candidates, value_type)
        predicate = tree.make_condition(membership)
        row_filter = gandiva.make_filter(tbl.schema, predicate)
        selection = row_filter.evaluate(
            tbl.to_batches()[0], pa.default_memory_pool())
        return list(selection.to_array())

    # string
    words = pa.array([u"ga", u"an", u"nd", u"di", u"iv", u"va"])
    string_table = pa.Table.from_arrays([words], ["a"])
    assert selected(string_table, [u"an", u"nd"], pa.string()) == [1, 2]

    # int32
    digits = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
    int32_table = pa.Table.from_arrays([digits.cast(pa.int32())], ["a"])
    assert selected(int32_table, [1, 5], pa.int32()) == [1, 3, 4, 8]

    # int64
    digits = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
    int64_table = pa.Table.from_arrays([digits], ["a"])
    assert selected(int64_table, [1, 5], pa.int64()) == [1, 3, 4, 8]
| |
| |
@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, "
                         "time and date support.")
def test_in_expr_todo():
    """IN-expression filters over binary, timestamp, time and date columns.

    Currently skipped (see the decorator's reason). Each section builds a
    one-column table, filters it with an IN expression and checks the
    selected row indices.
    """
    import pyarrow.gandiva as gandiva
    # TODO: Implement reasonable support for timestamp, time & date.
    # Current exceptions:
    # pyarrow.lib.ArrowException: ExpressionValidationError:
    # Evaluation expression for IN clause returns XXXX values are of typeXXXX

    # binary
    arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [b'an', b'nd'], pa.binary())
    condition = builder.make_condition(cond)

    filter = gandiva.make_filter(table.schema, condition)
    result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1, 2]

    # timestamp
    datetime_1 = datetime.datetime.utcfromtimestamp(1542238951.621877)
    datetime_2 = datetime.datetime.utcfromtimestamp(1542238911.621877)
    datetime_3 = datetime.datetime.utcfromtimestamp(1542238051.621877)

    arr = pa.array([datetime_1, datetime_2, datetime_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    # NOTE(review): pa.array() presumably infers timestamp[us] for datetime
    # objects, which would not match timestamp('ms') here -- confirm the
    # intended unit before enabling this test.
    cond = builder.make_in_expression(node_a, [datetime_2], pa.timestamp('ms'))
    condition = builder.make_condition(cond)

    filter = gandiva.make_filter(table.schema, condition)
    result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1]

    # time
    time_1 = datetime_1.time()
    time_2 = datetime_2.time()
    time_3 = datetime_3.time()

    arr = pa.array([time_1, time_2, time_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    # time64 only accepts the 'us' and 'ns' units; 'ms' is rejected with a
    # ValueError when the type is constructed, before Gandiva is reached.
    cond = builder.make_in_expression(node_a, [time_2], pa.time64('us'))
    condition = builder.make_condition(cond)

    filter = gandiva.make_filter(table.schema, condition)
    result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1]

    # date
    date_1 = datetime_1.date()
    date_2 = datetime_2.date()
    date_3 = datetime_3.date()

    arr = pa.array([date_1, date_2, date_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [date_2], pa.date32())
    condition = builder.make_condition(cond)

    filter = gandiva.make_filter(table.schema, condition)
    result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1]
| |
| |
@pytest.mark.gandiva
def test_boolean():
    """Filter with ``(a < 50 and a > b) or b < 11`` built from and/or nodes."""
    import pyarrow.gandiva as gandiva

    col_a = pa.array([1., 31., 46., 3., 57., 44., 22.])
    col_b = pa.array([5., 45., 36., 73., 83., 23., 76.])
    tbl = pa.Table.from_arrays([col_a, col_b], ['a', 'b'])
    schema = tbl.schema

    tree = gandiva.TreeExprBuilder()
    a_node = tree.make_field(schema.field_by_name("a"))
    b_node = tree.make_field(schema.field_by_name("b"))
    fifty = tree.make_literal(50.0, pa.float64())
    eleven = tree.make_literal(11.0, pa.float64())

    a_below_fifty = tree.make_function(
        "less_than", [a_node, fifty], pa.bool_())
    a_above_b = tree.make_function(
        "greater_than", [a_node, b_node], pa.bool_())
    b_below_eleven = tree.make_function(
        "less_than", [b_node, eleven], pa.bool_())

    both = tree.make_and([a_below_fifty, a_above_b])
    either = tree.make_or([both, b_below_eleven])
    predicate = tree.make_condition(either)

    row_filter = gandiva.make_filter(schema, predicate)
    selection = row_filter.evaluate(
        tbl.to_batches()[0], pa.default_memory_pool())
    assert list(selection.to_array()) == [0, 2, 5]
| |
| |
@pytest.mark.gandiva
def test_literals():
    """Build literals of each listed type, by DataType and by type name.

    Also checks that a value/type mismatch and a ``None`` type both raise
    TypeError.
    """
    import pyarrow.gandiva as gandiva

    tree = gandiva.TreeExprBuilder()

    # (value, DataType instance, equivalent type name)
    samples = [
        (True, pa.bool_(), "bool"),
        (0, pa.uint8(), "uint8"),
        (1, pa.uint16(), "uint16"),
        (2, pa.uint32(), "uint32"),
        (3, pa.uint64(), "uint64"),
        (4, pa.int8(), "int8"),
        (5, pa.int16(), "int16"),
        (6, pa.int32(), "int32"),
        (7, pa.int64(), "int64"),
        (8.0, pa.float32(), "float32"),
        (9.0, pa.float64(), "float64"),
        ("hello", pa.string(), "string"),
        (b"world", pa.binary(), "binary"),
    ]

    # First pass: types given as DataType objects.
    for value, data_type, _ in samples:
        tree.make_literal(value, data_type)

    # Second pass: the same types given by their string names.
    for value, _, type_name in samples:
        tree.make_literal(value, type_name)

    with pytest.raises(TypeError):
        tree.make_literal("hello", pa.int64())
    with pytest.raises(TypeError):
        tree.make_literal(True, None)
| |
| |
@pytest.mark.gandiva
def test_regex():
    """Project ``like(a, "%spark%")`` over a string column."""
    import pyarrow.gandiva as gandiva

    words = ["park", "sparkle", "bright spark and fire", "spark"]
    tbl = pa.Table.from_arrays(
        [pa.array(words, type=pa.string())], names=['a'])

    tree = gandiva.TreeExprBuilder()
    column = tree.make_field(tbl.schema.field_by_name("a"))
    pattern = tree.make_literal("%spark%", pa.string())
    matches = tree.make_function("like", [column, pattern], pa.bool_())
    expression = tree.make_expression(matches, pa.field("b", pa.bool_()))

    projector = gandiva.make_projector(
        tbl.schema, [expression], pa.default_memory_pool())

    (projected,) = projector.evaluate(tbl.to_batches()[0])
    expected = pa.array([False, True, True, True], type=pa.bool_())
    assert projected.equals(expected)
| |
| |
@pytest.mark.gandiva
def test_get_registered_function_signatures():
    """Check the types exposed by the first registered function signature."""
    import pyarrow.gandiva as gandiva

    first = gandiva.get_registered_function_signatures()[0]

    # Exact type checks (not isinstance) are intentional here.
    assert type(first.return_type()) is pa.DataType
    assert type(first.param_types()) is list
    assert hasattr(first, "name")