blob: 34b3cb74bb4df96e53064639ef39f00a0829855b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import pytest
import pyarrow as pa
@pytest.mark.gandiva
def test_tree_exp_builder():
    """Project ``if a > b then a else b`` over two int32 columns."""
    import pyarrow.gandiva as gandiva

    int32 = pa.int32()
    field_a = pa.field('a', int32)
    field_b = pa.field('b', int32)
    schema = pa.schema([field_a, field_b])
    out_field = pa.field('res', int32)

    tree = gandiva.TreeExprBuilder()
    lhs = tree.make_field(field_a)
    rhs = tree.make_field(field_b)
    is_greater = tree.make_function("greater_than", [lhs, rhs], pa.bool_())
    # Picks a where a > b, otherwise b, i.e. the element-wise maximum.
    max_node = tree.make_if(is_greater, lhs, rhs, int32)
    projector = gandiva.make_projector(
        schema,
        [tree.make_expression(max_node, out_field)],
        pa.default_memory_pool())

    batch = pa.RecordBatch.from_arrays(
        [pa.array([10, 12, -20, 5], type=int32),
         pa.array([5, 15, 15, 17], type=int32)],
        names=['a', 'b'])
    expected = pa.array([10, 15, 15, 17], type=int32)

    (actual,) = projector.evaluate(batch)
    assert actual.equals(expected)
@pytest.mark.gandiva
def test_table():
    """Project ``c = a + b`` over the first record batch of a float64 table."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays([pa.array([1.0, 2.0]), pa.array([3.0, 4.0])],
                                 ['a', 'b'])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    node_b = builder.make_field(table.schema.field_by_name("b"))
    # Named ``sum_node`` rather than ``sum`` to avoid shadowing the builtin.
    sum_node = builder.make_function("add", [node_a, node_b], pa.float64())
    field_result = pa.field("c", pa.float64())
    expr = builder.make_expression(sum_node, field_result)

    projector = gandiva.make_projector(
        table.schema, [expr], pa.default_memory_pool())

    # TODO: Add .evaluate function which can take Tables instead of
    # RecordBatches
    r, = projector.evaluate(table.to_batches()[0])

    e = pa.array([4.0, 6.0])
    assert r.equals(e)
@pytest.mark.gandiva
def test_filter():
    """Filter rows where ``a < 1000.0`` and check the selected row indices."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays([pa.array([1.0 * i for i in range(10000)])],
                                 ['a'])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    thousand = builder.make_literal(1000.0, pa.float64())
    cond = builder.make_function("less_than", [node_a, thousand], pa.bool_())
    condition = builder.make_condition(cond)

    # ``row_filter`` rather than ``filter`` to avoid shadowing the builtin.
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    # The expected value shows the result is a uint32 array of selected row
    # indices; since a[i] == i, rows 0..999 are exactly those below 1000.
    assert result.to_array().equals(pa.array(range(1000), type=pa.uint32()))
@pytest.mark.gandiva
def test_in_expr():
    """Filter with ``a IN (...)`` on string, int32 and int64 columns."""
    import pyarrow.gandiva as gandiva

    # string
    arr = pa.array([u"ga", u"an", u"nd", u"di", u"iv", u"va"])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [u"an", u"nd"], pa.string())
    condition = builder.make_condition(cond)
    # ``row_filter`` rather than ``filter`` to avoid shadowing the builtin.
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1, 2]

    # int32: the column is cast so its type matches the IN value type.
    arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
    table = pa.Table.from_arrays([arr.cast(pa.int32())], ["a"])

    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [1, 5], pa.int32())
    condition = builder.make_condition(cond)
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1, 3, 4, 8]

    # int64
    arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
    table = pa.Table.from_arrays([arr], ["a"])

    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [1, 5], pa.int64())
    condition = builder.make_condition(cond)
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1, 3, 4, 8]
@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, "
                         "time and date support.")
def test_in_expr_todo():
    """IN-expression filters on binary, timestamp, time and date columns.

    Currently skipped; kept so the cases can be enabled once Gandiva
    supports these types.
    """
    import pyarrow.gandiva as gandiva
    # TODO: Implement reasonable support for timestamp, time & date.
    # Current exceptions:
    # pyarrow.lib.ArrowException: ExpressionValidationError:
    # Evaluation expression for IN clause returns XXXX values are of typeXXXX

    # binary
    arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [b'an', b'nd'], pa.binary())
    condition = builder.make_condition(cond)
    # ``row_filter`` rather than ``filter`` to avoid shadowing the builtin.
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1, 2]

    # timestamp
    datetime_1 = datetime.datetime.utcfromtimestamp(1542238951.621877)
    datetime_2 = datetime.datetime.utcfromtimestamp(1542238911.621877)
    datetime_3 = datetime.datetime.utcfromtimestamp(1542238051.621877)
    arr = pa.array([datetime_1, datetime_2, datetime_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    # NOTE(review): the column type is inferred from datetime objects
    # (presumably microsecond unit) while the IN values are typed 'ms' --
    # confirm whether this mismatch is intended once this test is enabled.
    cond = builder.make_in_expression(node_a, [datetime_2],
                                      pa.timestamp('ms'))
    condition = builder.make_condition(cond)
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1]

    # time
    time_1 = datetime_1.time()
    time_2 = datetime_2.time()
    time_3 = datetime_3.time()
    arr = pa.array([time_1, time_2, time_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    # time64 only accepts 'us' or 'ns' (time64('ms') raises ValueError);
    # 'us' matches the type pyarrow infers for datetime.time values.
    cond = builder.make_in_expression(node_a, [time_2], pa.time64('us'))
    condition = builder.make_condition(cond)
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1]

    # date
    date_1 = datetime_1.date()
    date_2 = datetime_2.date()
    date_3 = datetime_3.date()
    arr = pa.array([date_1, date_2, date_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    cond = builder.make_in_expression(node_a, [date_2], pa.date32())
    condition = builder.make_condition(cond)
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    assert list(result.to_array()) == [1]
@pytest.mark.gandiva
def test_boolean():
    """Filter with ``(a < 50 AND a > b) OR b < 11`` on two float64 columns."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays(
        [pa.array([1., 31., 46., 3., 57., 44., 22.]),
         pa.array([5., 45., 36., 73., 83., 23., 76.])],
        ['a', 'b'])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field_by_name("a"))
    node_b = builder.make_field(table.schema.field_by_name("b"))
    fifty = builder.make_literal(50.0, pa.float64())
    eleven = builder.make_literal(11.0, pa.float64())

    cond_1 = builder.make_function("less_than", [node_a, fifty], pa.bool_())
    cond_2 = builder.make_function("greater_than", [node_a, node_b],
                                   pa.bool_())
    cond_3 = builder.make_function("less_than", [node_b, eleven], pa.bool_())
    cond = builder.make_or([builder.make_and([cond_1, cond_2]), cond_3])
    condition = builder.make_condition(cond)

    # ``row_filter`` rather than ``filter`` to avoid shadowing the builtin.
    row_filter = gandiva.make_filter(table.schema, condition)
    result = row_filter.evaluate(table.to_batches()[0],
                                 pa.default_memory_pool())
    # Row 0 matches via b < 11; rows 2 and 5 via (a < 50 and a > b).
    assert list(result.to_array()) == [0, 2, 5]
@pytest.mark.gandiva
def test_literals():
    """Build literals of several types, given both as DataType and by name."""
    import pyarrow.gandiva as gandiva

    tree = gandiva.TreeExprBuilder()

    # (value, pyarrow DataType, equivalent type-name string)
    samples = [
        (True, pa.bool_(), "bool"),
        (0, pa.uint8(), "uint8"),
        (1, pa.uint16(), "uint16"),
        (2, pa.uint32(), "uint32"),
        (3, pa.uint64(), "uint64"),
        (4, pa.int8(), "int8"),
        (5, pa.int16(), "int16"),
        (6, pa.int32(), "int32"),
        (7, pa.int64(), "int64"),
        (8.0, pa.float32(), "float32"),
        (9.0, pa.float64(), "float64"),
        ("hello", pa.string(), "string"),
        (b"world", pa.binary(), "binary"),
    ]

    # First pass: the type is given as a DataType instance.
    for value, arrow_type, _ in samples:
        tree.make_literal(value, arrow_type)

    # Second pass: the same literals, with the type given by name.
    for value, _, type_name in samples:
        tree.make_literal(value, type_name)

    # A value that doesn't fit the type, and a missing type, raise TypeError.
    with pytest.raises(TypeError):
        tree.make_literal("hello", pa.int64())
    with pytest.raises(TypeError):
        tree.make_literal(True, None)
@pytest.mark.gandiva
def test_regex():
    """Project ``like(a, '%spark%')`` over a string column."""
    import pyarrow.gandiva as gandiva

    words = ["park", "sparkle", "bright spark and fire", "spark"]
    table = pa.Table.from_arrays([pa.array(words, type=pa.string())],
                                 names=['a'])

    tree = gandiva.TreeExprBuilder()
    column = tree.make_field(table.schema.field_by_name("a"))
    pattern = tree.make_literal("%spark%", pa.string())
    matches = tree.make_function("like", [column, pattern], pa.bool_())
    expression = tree.make_expression(matches, pa.field("b", pa.bool_()))
    projector = gandiva.make_projector(
        table.schema, [expression], pa.default_memory_pool())

    (actual,) = projector.evaluate(table.to_batches()[0])
    # Only "park" lacks the "spark" substring.
    expected = pa.array([False, True, True, True], type=pa.bool_())
    assert actual.equals(expected)
@pytest.mark.gandiva
def test_get_registered_function_signatures():
    """Check the shape of the first registered Gandiva function signature."""
    import pyarrow.gandiva as gandiva

    signatures = gandiva.get_registered_function_signatures()
    # Fail with a clear assertion (not an IndexError) if nothing is registered.
    assert len(signatures) > 0

    first = signatures[0]
    # isinstance rather than ``type(...) is`` so that DataType subclasses
    # (e.g. parameterized types) are accepted as return types too.
    assert isinstance(first.return_type(), pa.DataType)
    assert isinstance(first.param_types(), list)
    assert hasattr(first, "name")