blob: 54b6d43b4df1cd52f70df827b4ab59c6f6c605b8 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "gandiva/projector.h"
#include "gandiva/tests/test_util.h"
#include "gandiva/tree_expr_builder.h"
namespace gandiva {
using arrow::boolean;
using arrow::float32;
using arrow::int32;
class TestIfExpr : public ::testing::Test {
public:
void SetUp() { pool_ = arrow::default_memory_pool(); }
protected:
arrow::MemoryPool* pool_;
};
TEST_F(TestIfExpr, TestSimple) {
// schema for input fields
auto fielda = field("a", int32());
auto fieldb = field("b", int32());
auto schema = arrow::schema({fielda, fieldb});
// output fields
auto field_result = field("res", int32());
// build expression.
// if (a > b)
// a
// else
// b
auto node_a = TreeExprBuilder::MakeField(fielda);
auto node_b = TreeExprBuilder::MakeField(fieldb);
auto condition =
TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());
auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
// Build a projector for the expressions.
std::shared_ptr<Projector> projector;
auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());
// Create a row-batch with some sample data
int num_records = 4;
auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});
// expected output
auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true});
// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
TEST_F(TestIfExpr, TestSimpleArithmetic) {
// schema for input fields
auto fielda = field("a", int32());
auto fieldb = field("b", int32());
auto schema = arrow::schema({fielda, fieldb});
// output fields
auto field_result = field("res", int32());
// build expression.
// if (a > b)
// a + b
// else
// a - b
auto node_a = TreeExprBuilder::MakeField(fielda);
auto node_b = TreeExprBuilder::MakeField(fieldb);
auto condition =
TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32());
auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
// Build a projector for the expressions.
std::shared_ptr<Projector> projector;
auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());
// Create a row-batch with some sample data
int num_records = 4;
auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false});
auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});
// expected output
auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false});
// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
TEST_F(TestIfExpr, TestNested) {
// schema for input fields
auto fielda = field("a", int32());
auto fieldb = field("b", int32());
auto schema = arrow::schema({fielda, fieldb});
// output fields
auto field_result = field("res", int32());
// build expression.
// if (a > b)
// a + b
// else if (a < b)
// a - b
// else
// a * b
auto node_a = TreeExprBuilder::MakeField(fielda);
auto node_b = TreeExprBuilder::MakeField(fieldb);
auto condition_gt =
TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
auto condition_lt =
TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32());
auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32());
auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32());
auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32());
auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
// Build a projector for the expressions.
std::shared_ptr<Projector> projector;
auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());
// Create a row-batch with some sample data
int num_records = 4;
auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false});
auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true});
// expected output
auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false});
// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
TEST_F(TestIfExpr, TestNestedInIf) {
// schema for input fields
auto fielda = field("a", int32());
auto fieldb = field("b", int32());
auto fieldc = field("c", int32());
auto schema = arrow::schema({fielda, fieldb, fieldc});
// output fields
auto field_result = field("res", int32());
// build expression.
// if (a > 10)
// if (a < 20)
// a + b
// else
// b + c
// else
// a + c
auto node_a = TreeExprBuilder::MakeField(fielda);
auto node_b = TreeExprBuilder::MakeField(fieldb);
auto node_c = TreeExprBuilder::MakeField(fieldc);
auto literal_10 = TreeExprBuilder::MakeLiteral(10);
auto literal_20 = TreeExprBuilder::MakeLiteral(20);
auto gt_10 =
TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean());
auto lt_20 =
TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean());
auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32());
auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32());
auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32());
auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32());
auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32());
auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result);
// Build a projector for the expressions.
std::shared_ptr<Projector> projector;
auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());
// Create a row-batch with some sample data
int num_records = 6;
auto array_a =
MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
{true, true, true, false, false, false});
auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55},
{true, true, true, false, false, false});
// expected output
auto exp =
MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false});
// prepare input record batch
auto in_batch =
arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
TEST_F(TestIfExpr, TestNestedInCondition) {
// schema for input fields
auto fielda = field("a", int32());
auto fieldb = field("b", int32());
auto schema = arrow::schema({fielda, fieldb});
// output fields
auto field_result = field("res", int32());
// build expression.
// if (if (a > b) then true else if (a < b) false else null)
// 1
// else if !(if (a > b) then true else if (a < b) false else null)
// 2
// else
// 3
auto node_a = TreeExprBuilder::MakeField(fielda);
auto node_b = TreeExprBuilder::MakeField(fieldb);
auto literal_1 = TreeExprBuilder::MakeLiteral(1);
auto literal_2 = TreeExprBuilder::MakeLiteral(2);
auto literal_3 = TreeExprBuilder::MakeLiteral(3);
auto literal_true = TreeExprBuilder::MakeLiteral(true);
auto literal_false = TreeExprBuilder::MakeLiteral(false);
auto literal_null = TreeExprBuilder::MakeNull(boolean());
auto a_gt_b =
TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean());
auto cond_else =
TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean());
auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, boolean());
auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, boolean());
auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, int32());
auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, int32());
auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result);
// Build a projector for the expressions.
std::shared_ptr<Projector> projector;
auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());
// Create a row-batch with some sample data
int num_records = 6;
auto array_a =
MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true});
auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
{true, true, true, false, false, false});
// expected output
auto exp =
MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, true});
// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
TEST_F(TestIfExpr, TestBigNested) {
// schema for input fields
auto fielda = field("a", int32());
auto schema = arrow::schema({fielda});
// output fields
auto field_result = field("res", int32());
// build expression.
// if (a < 10)
// 10
// else if (a < 20)
// 20
// ..
// ..
// else if (a < 190)
// 190
// else
// 200
auto node_a = TreeExprBuilder::MakeField(fielda);
auto top_node = TreeExprBuilder::MakeLiteral(200);
for (int thresh = 190; thresh > 0; thresh -= 10) {
auto literal = TreeExprBuilder::MakeLiteral(thresh);
auto condition =
TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean());
auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32());
top_node = if_node;
}
auto expr = TreeExprBuilder::MakeExpression(top_node, field_result);
// Build a projector for the expressions.
std::shared_ptr<Projector> projector;
auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
EXPECT_TRUE(status.ok());
// Create a row-batch with some sample data
int num_records = 4;
auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true});
// expected output
auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true});
// prepare input record batch
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
// Evaluate expression
arrow::ArrayVector outputs;
status = projector->Evaluate(*in_batch, pool_, &outputs);
EXPECT_TRUE(status.ok());
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
} // namespace gandiva