blob: 3a8a3a0eb85300599fb6634e24f7803a8d437b3d [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/registry.h"
#include <algorithm>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include "arrow/compute/function.h"
#include "arrow/compute/registry_internal.h"
#include "arrow/status.h"
namespace arrow {
namespace compute {
/// Private implementation behind FunctionRegistry: a map from function name
/// (or alias) to the shared Function object registered under it.
///
/// Every member function, readers included, takes `lock_`. Lookups therefore
/// never observe the map while AddFunction/AddAlias is modifying it.
class FunctionRegistry::FunctionRegistryImpl {
 public:
  /// Register `function` under `function->name()`.
  ///
  /// \param[in] function the function to register; Validate() is run first
  /// \param[in] allow_overwrite when false, refuse to replace an existing entry
  /// \return the Validate() error if validation fails; KeyError if the name is
  ///         already registered and `allow_overwrite` is false; OK otherwise
  Status AddFunction(std::shared_ptr<Function> function, bool allow_overwrite) {
    // Validation does not touch the map, so run it outside the critical section.
    RETURN_NOT_OK(function->Validate());

    std::lock_guard<std::mutex> mutation_guard(lock_);

    // `name` refers into *function. The Function object stays alive (it is
    // owned by the map) after the shared_ptr is moved, so the reference
    // stays valid.
    const std::string& name = function->name();
    auto it = name_to_function_.find(name);
    if (it != name_to_function_.end()) {
      if (!allow_overwrite) {
        return Status::KeyError("Already have a function registered with name: ", name);
      }
      // Reuse the lookup instead of hashing `name` a second time via operator[].
      it->second = std::move(function);
    } else {
      name_to_function_.emplace(name, std::move(function));
    }
    return Status::OK();
  }

  /// Make `target_name` refer to the function currently registered under
  /// `source_name`. An existing entry for `target_name` is replaced.
  ///
  /// \return KeyError if nothing is registered under `source_name`
  Status AddAlias(const std::string& target_name, const std::string& source_name) {
    std::lock_guard<std::mutex> mutation_guard(lock_);
    auto it = name_to_function_.find(source_name);
    if (it == name_to_function_.end()) {
      return Status::KeyError("No function registered with name: ", source_name);
    }
    // Copy the target out *before* operator[] can insert and rehash. Rehashing
    // invalidates `it`, and pre-C++17 the operands of `=` are unsequenced, so
    // `it->second` could otherwise be read through an invalidated iterator.
    std::shared_ptr<Function> function = it->second;
    name_to_function_[target_name] = std::move(function);
    return Status::OK();
  }

  /// Look up a function by name or alias.
  ///
  /// \return the registered function, or KeyError if `name` is unknown
  Result<std::shared_ptr<Function>> GetFunction(const std::string& name) const {
    std::lock_guard<std::mutex> guard(lock_);
    auto it = name_to_function_.find(name);
    if (it == name_to_function_.end()) {
      return Status::KeyError("No function registered with name: ", name);
    }
    return it->second;
  }

  /// All registered names (aliases included), sorted ascending.
  std::vector<std::string> GetFunctionNames() const {
    std::vector<std::string> results;
    {
      std::lock_guard<std::mutex> guard(lock_);
      results.reserve(name_to_function_.size());
      // Bind by const reference. A by-value loop variable would copy each
      // shared_ptr, costing an atomic refcount increment and decrement.
      for (const auto& name_and_function : name_to_function_) {
        results.push_back(name_and_function.first);
      }
    }
    // Sorting works on the local copy, so it runs after the lock is released.
    std::sort(results.begin(), results.end());
    return results;
  }

  /// Number of registered entries (each alias counts as its own entry).
  int num_functions() const {
    std::lock_guard<std::mutex> guard(lock_);
    return static_cast<int>(name_to_function_.size());
  }

 private:
  // mutable so that the const accessors above can lock it too.
  mutable std::mutex lock_;
  std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
};
std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() {
return std::unique_ptr<FunctionRegistry>(new FunctionRegistry());
}
// The registry starts with a freshly allocated, empty implementation object.
FunctionRegistry::FunctionRegistry() : impl_(new FunctionRegistryImpl()) {}

// Defined out of line, in the translation unit where FunctionRegistryImpl is a
// complete type, so that impl_ can destroy it.
FunctionRegistry::~FunctionRegistry() = default;
/// Register `function`; forwards to FunctionRegistryImpl::AddFunction.
Status FunctionRegistry::AddFunction(std::shared_ptr<Function> function,
                                     bool allow_overwrite) {
  FunctionRegistryImpl& impl = *impl_;
  // Ownership of `function` is handed straight through to the implementation.
  return impl.AddFunction(std::move(function), allow_overwrite);
}

/// Alias `target_name` to `source_name`; forwards to FunctionRegistryImpl::AddAlias.
Status FunctionRegistry::AddAlias(const std::string& target_name,
                                  const std::string& source_name) {
  FunctionRegistryImpl& impl = *impl_;
  return impl.AddAlias(target_name, source_name);
}
Result<std::shared_ptr<Function>> FunctionRegistry::GetFunction(
const std::string& name) const {
return impl_->GetFunction(name);
}
std::vector<std::string> FunctionRegistry::GetFunctionNames() const {
return impl_->GetFunctionNames();
}
int FunctionRegistry::num_functions() const { return impl_->num_functions(); }
namespace internal {
static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
auto registry = FunctionRegistry::Make();
// Scalar functions
RegisterScalarArithmetic(registry.get());
RegisterScalarBoolean(registry.get());
RegisterScalarCast(registry.get());
RegisterScalarComparison(registry.get());
RegisterScalarNested(registry.get());
RegisterScalarSetLookup(registry.get());
RegisterScalarStringAscii(registry.get());
RegisterScalarValidity(registry.get());
RegisterScalarFillNull(registry.get());
// Vector functions
RegisterVectorHash(registry.get());
RegisterVectorSelection(registry.get());
RegisterVectorNested(registry.get());
RegisterVectorSort(registry.get());
// Aggregate functions
RegisterScalarAggregateBasic(registry.get());
RegisterScalarAggregateMode(registry.get());
RegisterScalarAggregateQuantile(registry.get());
RegisterScalarAggregateTDigest(registry.get());
RegisterScalarAggregateVariance(registry.get());
RegisterHashAggregateBasic(registry.get());
return registry;
}
} // namespace internal
/// Process-wide registry of built-in functions, created on first use.
FunctionRegistry* GetFunctionRegistry() {
  // A function-local static is initialized exactly once, even if several
  // threads race on the first call (guaranteed since C++11).
  static std::unique_ptr<FunctionRegistry> builtin_registry =
      internal::CreateBuiltInRegistry();
  return builtin_registry.get();
}
} // namespace compute
} // namespace arrow