blob: 0374b68d4628896386fe827a35d8a5f0cc24cf67 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>
#include "arrow/memory_pool.h"
#include "gandiva/filter.h"
#include "gandiva/projector.h"
#include "gandiva/tests/test_util.h"
#include "gandiva/tree_expr_builder.h"
namespace gandiva {
using arrow::boolean;
using arrow::int32;
using arrow::utf8;
// Test fixture shared by the null-validity tests in this file.
// It only provides the memory pool that the tests hand to Gandiva/Arrow calls.
class TestNullValidity : public ::testing::Test {
 public:
  // Invoked by GoogleTest before each test body; `override` makes the compiler
  // verify that this really replaces ::testing::Test::SetUp().
  void SetUp() override { pool_ = arrow::default_memory_pool(); }

 protected:
  // Pool obtained from arrow::default_memory_pool(); the fixture never frees it.
  // Initialised to nullptr so it is never read in an indeterminate state.
  arrow::MemoryPool* pool_ = nullptr;
};
// Build an Int32 array that has no validity buffer.
//
// A regular array is first created from `in_data`; its values buffer
// (buffers[1]) is then re-wrapped in a new Int32Array whose validity bitmap is
// nullptr and whose null count is 0.
ArrayPtr MakeArrowArrayInt32WithNullValidity(std::vector<int32_t> in_data) {
  const auto with_validity = MakeArrowArrayInt32(in_data);
  const auto& values_buffer = with_validity->data()->buffers[1];
  return std::make_shared<arrow::Int32Array>(in_data.size(), values_buffer,
                                             /*null_bitmap=*/nullptr,
                                             /*null_count=*/0);
}
// Filter `f0 + f1 < 10` where f0 is supplied without a validity buffer and f1
// is supplied with an explicit validity vector.
TEST_F(TestNullValidity, TestFunc) {
  // schema for input fields
  auto field0 = field("f0", int32());
  auto field1 = field("f1", int32());
  auto schema = arrow::schema({field0, field1});

  // Build condition f0 + f1 < 10
  auto node_f0 = TreeExprBuilder::MakeField(field0);
  auto node_f1 = TreeExprBuilder::MakeField(field1);
  auto sum_func =
      TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
  auto literal_10 = TreeExprBuilder::MakeLiteral(static_cast<int32_t>(10));
  auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
                                                    arrow::boolean());
  auto condition = TreeExprBuilder::MakeCondition(less_than_10);

  std::shared_ptr<Filter> filter;
  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
  // Fatal assertion: `filter` is dereferenced below, so stop here on failure.
  ASSERT_TRUE(status.ok()) << status.message();

  // Create a row-batch with some sample data
  const int num_records = 5;
  // f0: array without a validity buffer.
  auto array0 = MakeArrowArrayInt32WithNullValidity({1, 2, 3, 4, 6});
  // f1: array with a validity vector whose third entry is false.
  auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});

  // expected output (indices for which condition matches):
  // row 0 -> 1 + 5 = 6 and row 4 -> 6 + 3 = 9 are below 10; rows 1 and 3 sum
  // to 11 and 21; row 2 is the entry marked false in f1.
  auto exp = MakeArrowArrayUint16({0, 4});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  std::shared_ptr<SelectionVector> selection_vector;
  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
  // Fatal assertion: `selection_vector` is used by Evaluate() and ToArray().
  ASSERT_TRUE(status.ok()) << status.message();

  // Evaluate expression
  status = filter->Evaluate(*in_batch, selection_vector);
  ASSERT_TRUE(status.ok()) << status.message();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
}
// Project `if (a > b) a else b` where column a is supplied without a validity
// buffer.
TEST_F(TestNullValidity, TestIfElse) {
  // schema for input fields
  auto fielda = field("a", int32());
  auto fieldb = field("b", int32());
  auto schema = arrow::schema({fielda, fieldb});

  // output fields
  auto field_result = field("res", int32());

  // build expression.
  // if (a > b)
  //   a
  // else
  //   b
  auto node_a = TreeExprBuilder::MakeField(fielda);
  auto node_b = TreeExprBuilder::MakeField(fieldb);
  auto condition =
      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
  auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32());
  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal assertion: `projector` is dereferenced below, so stop here on failure.
  ASSERT_TRUE(status.ok()) << status.message();

  // Create a row-batch with some sample data
  const int num_records = 4;
  // a: array without a validity buffer.
  auto array0 = MakeArrowArrayInt32WithNullValidity({10, 12, -20, 5});
  auto array1 = MakeArrowArrayInt32({5, 15, 15, 17});

  // expected output: the larger of each (a, b) pair, all entries valid.
  auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal assertion: outputs.at(0) below would throw on an empty vector.
  ASSERT_TRUE(status.ok()) << status.message();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
// Project `length(a)` over a utf8 column that is supplied without a validity
// buffer.
TEST_F(TestNullValidity, TestUtf8) {
  // schema for input fields
  auto field_a = field("a", utf8());
  auto schema = arrow::schema({field_a});

  // output fields
  auto res = field("res1", int32());

  // build expressions.
  // length(a)
  auto expr = TreeExprBuilder::MakeExpression("length", {field_a}, res);

  // Build a projector for the expressions.
  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
  // Fatal assertion: `projector` is dereferenced below, so stop here on failure.
  ASSERT_TRUE(status.ok()) << status.message();

  // Create a row-batch with some sample data
  const int num_records = 5;
  auto array_v = MakeArrowArrayUtf8({"foo", "hello", "bye", "hi", "मदन"});
  // Rebuild the string array from buffers[1] and buffers[2] only, so that no
  // validity buffer is passed to the StringArray constructor.
  auto array_a = std::make_shared<arrow::StringArray>(
      num_records, array_v->data()->buffers[1], array_v->data()->buffers[2]);

  // expected output: the last string is 9 bytes in UTF-8 but is expected to
  // report a length of 3.
  auto exp = MakeArrowArrayInt32({3, 5, 3, 2, 3}, {true, true, true, true, true});

  // prepare input record batch
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});

  // Evaluate expression
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal assertion: outputs.at(0) below would throw on an empty vector.
  ASSERT_TRUE(status.ok()) << status.message();

  // Validate results
  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
} // namespace gandiva