blob: bbe72c0ee970c5def20f9ff191f957a9895dbdc7 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "gandiva/function_registry.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include "gandiva/tests/test_util.h"
namespace gandiva {
class TestFunctionRegistry : public ::testing::Test {
protected:
std::shared_ptr<FunctionRegistry> registry_ = gandiva::default_function_registry();
static std::unique_ptr<FunctionRegistry> MakeFunctionRegistryWithExternalFunction() {
auto registry = std::make_unique<FunctionRegistry>();
ARROW_EXPECT_OK(
registry->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath()));
return registry;
}
};
TEST_F(TestFunctionRegistry, TestFound) {
FunctionSignature add_i32_i32("add", {arrow::int32(), arrow::int32()}, arrow::int32());
const NativeFunction* function = registry_->LookupSignature(add_i32_i32);
EXPECT_NE(function, nullptr);
EXPECT_THAT(function->signatures(), testing::Contains(add_i32_i32));
EXPECT_EQ(function->pc_name(), "add_int32_int32");
}
TEST_F(TestFunctionRegistry, TestNotFound) {
FunctionSignature addX_i32_i32("addX", {arrow::int32(), arrow::int32()},
arrow::int32());
EXPECT_EQ(registry_->LookupSignature(addX_i32_i32), nullptr);
FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()},
arrow::int64());
EXPECT_EQ(registry_->LookupSignature(add_i32_i32_ret64), nullptr);
}
TEST_F(TestFunctionRegistry, TestCustomFunctionRegistry) {
auto registry = MakeFunctionRegistryWithExternalFunction();
auto multiply_by_two_func = GetTestExternalFunction();
auto multiply_by_two_int32_ret64 = multiply_by_two_func.signatures().front();
EXPECT_NE(registry->LookupSignature(multiply_by_two_int32_ret64), nullptr);
FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()},
arrow::int64());
EXPECT_EQ(registry->LookupSignature(add_i32_i32_ret64), nullptr);
}
TEST_F(TestFunctionRegistry, TestGetBitcodeMemoryBuffersDefaultFunctionRegistry) {
EXPECT_EQ(registry_->GetBitcodeBuffers().size(), 0);
}
TEST_F(TestFunctionRegistry, TestGetBitcodeMemoryBuffersCustomFunctionRegistry) {
auto registry = MakeFunctionRegistryWithExternalFunction();
EXPECT_EQ(registry->GetBitcodeBuffers().size(), 1);
}
// one nativefunction object per precompiled function
TEST_F(TestFunctionRegistry, TestNoDuplicates) {
std::unordered_set<std::string> pc_func_sigs;
std::unordered_set<std::string> native_func_duplicates;
std::unordered_set<std::string> func_sigs;
std::unordered_set<std::string> func_sig_duplicates;
for (const auto& native_func_it : *registry_) {
auto& first_sig = native_func_it.signatures().front();
auto pc_func_sig = FunctionSignature(native_func_it.pc_name(),
first_sig.param_types(), first_sig.ret_type())
.ToString();
if (pc_func_sigs.count(pc_func_sig) == 0) {
pc_func_sigs.insert(pc_func_sig);
} else {
native_func_duplicates.insert(pc_func_sig);
}
for (auto& sig : native_func_it.signatures()) {
auto sig_str = sig.ToString();
if (func_sigs.count(sig_str) == 0) {
func_sigs.insert(sig_str);
} else {
func_sig_duplicates.insert(sig_str);
}
}
}
std::ostringstream stream;
std::copy(native_func_duplicates.begin(), native_func_duplicates.end(),
std::ostream_iterator<std::string>(stream, "\n"));
std::string result = stream.str();
EXPECT_TRUE(native_func_duplicates.empty())
<< "Registry has duplicates.\nMultiple NativeFunction objects refer to the "
"following precompiled functions:\n"
<< result;
stream.clear();
std::copy(func_sig_duplicates.begin(), func_sig_duplicates.end(),
std::ostream_iterator<std::string>(stream, "\n"));
EXPECT_TRUE(func_sig_duplicates.empty())
<< "The following signatures are defined more than once possibly pointing to "
"different precompiled functions:\n"
<< stream.str();
}
} // namespace gandiva