| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include <gtest/gtest.h> |
| #include "arrow/memory_pool.h" |
| #include "gandiva/filter.h" |
| #include "gandiva/projector.h" |
| #include "gandiva/tests/test_util.h" |
| #include "gandiva/tree_expr_builder.h" |
| |
| namespace gandiva { |
| |
| using arrow::boolean; |
| using arrow::int32; |
| using arrow::utf8; |
| |
| class TestNullValidity : public ::testing::Test { |
| public: |
| void SetUp() { pool_ = arrow::default_memory_pool(); } |
| |
| protected: |
| arrow::MemoryPool* pool_; |
| }; |
| |
| // Create an array without a validity buffer. |
| ArrayPtr MakeArrowArrayInt32WithNullValidity(std::vector<int32_t> in_data) { |
| auto array = MakeArrowArrayInt32(in_data); |
| return std::make_shared<arrow::Int32Array>(in_data.size(), array->data()->buffers[1], |
| nullptr, 0); |
| } |
| |
| TEST_F(TestNullValidity, TestFunc) { |
| // schema for input fields |
| auto field0 = field("f0", int32()); |
| auto field1 = field("f1", int32()); |
| auto schema = arrow::schema({field0, field1}); |
| |
| // Build condition f0 + f1 < 10 |
| auto node_f0 = TreeExprBuilder::MakeField(field0); |
| auto node_f1 = TreeExprBuilder::MakeField(field1); |
| auto sum_func = |
| TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); |
| auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); |
| auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10}, |
| arrow::boolean()); |
| auto condition = TreeExprBuilder::MakeCondition(less_than_10); |
| |
| std::shared_ptr<Filter> filter; |
| auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); |
| EXPECT_TRUE(status.ok()); |
| |
| // Create a row-batch with some sample data |
| int num_records = 5; |
| |
| // Create an array without a validity buffer. |
| auto array0 = MakeArrowArrayInt32WithNullValidity({1, 2, 3, 4, 6}); |
| auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true}); |
| // expected output (indices for which condition matches) |
| auto exp = MakeArrowArrayUint16({0, 4}); |
| |
| // prepare input record batch |
| auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); |
| |
| std::shared_ptr<SelectionVector> selection_vector; |
| status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); |
| EXPECT_TRUE(status.ok()); |
| |
| // Evaluate expression |
| status = filter->Evaluate(*in_batch, selection_vector); |
| EXPECT_TRUE(status.ok()); |
| |
| // Validate results |
| EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); |
| } |
| |
| TEST_F(TestNullValidity, TestIfElse) { |
| // schema for input fields |
| auto fielda = field("a", int32()); |
| auto fieldb = field("b", int32()); |
| auto schema = arrow::schema({fielda, fieldb}); |
| |
| // output fields |
| auto field_result = field("res", int32()); |
| |
| // build expression. |
| // if (a > b) |
| // a |
| // else |
| // b |
| auto node_a = TreeExprBuilder::MakeField(fielda); |
| auto node_b = TreeExprBuilder::MakeField(fieldb); |
| auto condition = |
| TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); |
| auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); |
| |
| auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); |
| |
| // Build a projector for the expressions. |
| std::shared_ptr<Projector> projector; |
| auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); |
| EXPECT_TRUE(status.ok()); |
| |
| // Create a row-batch with some sample data |
| int num_records = 4; |
| auto array0 = MakeArrowArrayInt32WithNullValidity({10, 12, -20, 5}); |
| auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}); |
| |
| // expected output |
| auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true}); |
| |
| // prepare input record batch |
| auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); |
| |
| // Evaluate expression |
| arrow::ArrayVector outputs; |
| status = projector->Evaluate(*in_batch, pool_, &outputs); |
| EXPECT_TRUE(status.ok()); |
| |
| // Validate results |
| EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); |
| } |
| |
| TEST_F(TestNullValidity, TestUtf8) { |
| // schema for input fields |
| auto field_a = field("a", utf8()); |
| auto schema = arrow::schema({field_a}); |
| |
| // output fields |
| auto res = field("res1", int32()); |
| |
| // build expressions. |
| // length(a) |
| auto expr = TreeExprBuilder::MakeExpression("length", {field_a}, res); |
| |
| // Build a projector for the expressions. |
| std::shared_ptr<Projector> projector; |
| auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); |
| EXPECT_TRUE(status.ok()) << status.message(); |
| |
| // Create a row-batch with some sample data |
| int num_records = 5; |
| auto array_v = MakeArrowArrayUtf8({"foo", "hello", "bye", "hi", "मदन"}); |
| auto array_a = std::make_shared<arrow::StringArray>( |
| num_records, array_v->data()->buffers[1], array_v->data()->buffers[2]); |
| |
| // expected output |
| auto exp = MakeArrowArrayInt32({3, 5, 3, 2, 3}, {true, true, true, true, true}); |
| |
| // prepare input record batch |
| auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); |
| |
| // Evaluate expression |
| arrow::ArrayVector outputs; |
| status = projector->Evaluate(*in_batch, pool_, &outputs); |
| EXPECT_TRUE(status.ok()); |
| |
| // Validate results |
| EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); |
| } |
| |
| } // namespace gandiva |