blob: f74bb245d770a1ea3b1ec41bc160897d5a3c2c87 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/function.h"
#include <cstddef>
#include <memory>
#include <sstream>
#include "arrow/compute/cast.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/datum.h"
#include "arrow/util/cpu_info.h"
namespace arrow {
using internal::checked_cast;
namespace compute {
namespace {
// File-local, value-initialized FunctionDoc shared by every caller of
// FunctionDoc::Empty().
const FunctionDoc kEmptyFunctionDoc{};
}  // namespace

// Returns a reference to the single file-local empty FunctionDoc instance.
const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; }
// Compares `passed_num_args` against the arity declared by `function`.
//
// For a varargs function, arity().num_args is a lower bound; for any other
// function the count must match exactly. `passed_num_args_label` is spliced
// into the error message to describe where the count came from.
// Returns Status::Invalid on a mismatch and Status::OK otherwise.
static Status CheckArityImpl(const Function* function, int passed_num_args,
                             const char* passed_num_args_label) {
  const auto& arity = function->arity();

  if (arity.is_varargs) {
    // Varargs: only too few arguments is an error
    if (passed_num_args < arity.num_args) {
      return Status::Invalid("VarArgs function ", function->name(), " needs at least ",
                             arity.num_args, " arguments but ", passed_num_args_label,
                             " only ", passed_num_args);
    }
    return Status::OK();
  }

  // Fixed arity: the count has to be exact
  if (passed_num_args != arity.num_args) {
    return Status::Invalid("Function ", function->name(), " accepts ", arity.num_args,
                           " arguments but ", passed_num_args_label, " ",
                           passed_num_args);
  }
  return Status::OK();
}
// Checks that a kernel accepting `in_types` is compatible with this function's
// arity. Failures are worded from the kernel's point of view.
Status Function::CheckArity(const std::vector<InputType>& in_types) const {
  const int num_in_types = static_cast<int>(in_types.size());
  return CheckArityImpl(this, num_in_types, "kernel accepts");
}
// Checks that the number of argument descriptors used for a kernel lookup is
// compatible with this function's arity.
Status Function::CheckArity(const std::vector<ValueDescr>& descrs) const {
  const int num_descrs = static_cast<int>(descrs.size());
  return CheckArityImpl(this, num_descrs, "attempted to look up kernel(s) with");
}
namespace detail {
// Builds the NotImplemented status reported when `func` has no kernel whose
// signature accepts the argument descriptors `descrs`.
Status NoMatchingKernel(const Function* func, const std::vector<ValueDescr>& descrs) {
  const auto described_inputs = ValueDescr::ToString(descrs);
  return Status::NotImplemented("Function ", func->name(),
                                " has no kernel matching input types ",
                                described_inputs);
}
template <typename KernelType>
const KernelType* DispatchExactImpl(const std::vector<KernelType*>& kernels,
const std::vector<ValueDescr>& values) {
const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr};
// Validate arity
for (const auto& kernel : kernels) {
if (kernel->signature->MatchesInputs(values)) {
kernel_matches[kernel->simd_level] = kernel;
}
}
// Dispatch as the CPU feature
#if defined(ARROW_HAVE_RUNTIME_AVX512) || defined(ARROW_HAVE_RUNTIME_AVX2)
auto cpu_info = arrow::internal::CpuInfo::GetInstance();
#endif
#if defined(ARROW_HAVE_RUNTIME_AVX512)
if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) {
if (kernel_matches[SimdLevel::AVX512]) {
return kernel_matches[SimdLevel::AVX512];
}
}
#endif
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
if (kernel_matches[SimdLevel::AVX2]) {
return kernel_matches[SimdLevel::AVX2];
}
}
#endif
if (kernel_matches[SimdLevel::NONE]) {
return kernel_matches[SimdLevel::NONE];
}
return nullptr;
}
// Downcasts `func` according to its kind() and runs the templated
// DispatchExactImpl over that function's kernels.
// Returns nullptr when no kernel matches `values` or when the kind is not one
// of SCALAR, VECTOR, SCALAR_AGGREGATE or HASH_AGGREGATE.
const Kernel* DispatchExactImpl(const Function* func,
                                const std::vector<ValueDescr>& values) {
  switch (func->kind()) {
    case Function::SCALAR:
      return DispatchExactImpl(checked_cast<const ScalarFunction*>(func)->kernels(),
                               values);
    case Function::VECTOR:
      return DispatchExactImpl(checked_cast<const VectorFunction*>(func)->kernels(),
                               values);
    case Function::SCALAR_AGGREGATE:
      return DispatchExactImpl(
          checked_cast<const ScalarAggregateFunction*>(func)->kernels(), values);
    case Function::HASH_AGGREGATE:
      return DispatchExactImpl(
          checked_cast<const HashAggregateFunction*>(func)->kernels(), values);
    default:
      return nullptr;
  }
}
} // namespace detail
// Looks up a kernel whose signature matches `values`.
//
// Returns NotImplemented for a META function, the CheckArity error when the
// number of descriptors does not fit this function's arity, and the
// NoMatchingKernel status when no kernel accepts `values`.
Result<const Kernel*> Function::DispatchExact(
    const std::vector<ValueDescr>& values) const {
  if (kind_ == Function::META) {
    return Status::NotImplemented("Dispatch for a MetaFunction's Kernels");
  }
  RETURN_NOT_OK(CheckArity(values));

  const Kernel* matched = detail::DispatchExactImpl(this, values);
  if (matched == nullptr) {
    return detail::NoMatchingKernel(this, values);
  }
  return matched;
}
// Base implementation: forwards to DispatchExact and leaves `*values`
// unmodified.
Result<const Kernel*> Function::DispatchBest(std::vector<ValueDescr>* values) const {
  // TODO(ARROW-11508) permit generic conversions here
  const std::vector<ValueDescr>& requested = *values;
  return DispatchExact(requested);
}
// Executes this function on `args` and returns the wrapped result.
//
// A null `options` is replaced by default_options(); a null `ctx` is replaced
// by a temporary default-constructed ExecContext. Only SCALAR, VECTOR and
// SCALAR_AGGREGATE functions are handled here; any other kind yields
// NotImplemented. That check runs before kernel dispatch so that an
// unsupported kind is reported without first casting the arguments or
// initializing kernel state.
//
// Errors from argument validation, kernel dispatch, implicit casting, kernel
// initialization and execution are propagated to the caller.
Result<Datum> Function::Execute(const std::vector<Datum>& args,
                                const FunctionOptions* options, ExecContext* ctx) const {
  if (options == nullptr) {
    options = default_options();
  }
  if (ctx == nullptr) {
    ExecContext default_ctx;
    return Execute(args, options, &default_ctx);
  }

  // type-check Datum arguments here. Really we'd like to avoid this as much as
  // possible
  RETURN_NOT_OK(detail::CheckAllValues(args));
  std::vector<ValueDescr> inputs(args.size());
  for (size_t i = 0; i != args.size(); ++i) {
    inputs[i] = args[i].descr();
  }

  // Choose the executor up front: kinds that cannot be executed directly are
  // rejected before any dispatch, casting or kernel-state work is done.
  std::unique_ptr<detail::KernelExecutor> executor;
  if (kind() == Function::SCALAR) {
    executor = detail::KernelExecutor::MakeScalar();
  } else if (kind() == Function::VECTOR) {
    executor = detail::KernelExecutor::MakeVector();
  } else if (kind() == Function::SCALAR_AGGREGATE) {
    executor = detail::KernelExecutor::MakeScalarAggregate();
  } else {
    return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions");
  }

  // DispatchBest receives a pointer to `inputs`; the arguments are then cast
  // to whatever descriptors `inputs` holds after dispatch.
  ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs));
  ARROW_ASSIGN_OR_RAISE(auto implicitly_cast_args, Cast(args, inputs, ctx));

  // `state` owns the kernel state; the context only stores a raw pointer to it
  std::unique_ptr<KernelState> state;
  KernelContext kernel_ctx{ctx};
  if (kernel->init) {
    ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, inputs, options}));
    kernel_ctx.SetState(state.get());
  }

  RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options}));

  // The accumulator never leaves this scope, so a local object suffices
  detail::DatumAccumulator listener;
  RETURN_NOT_OK(executor->Execute(implicitly_cast_args, &listener));
  return executor->WrapResults(implicitly_cast_args, listener.values());
}
// Checks that the documentation, when present, is consistent with the arity.
//
// A doc with an empty summary is treated as "no documentation" and always
// passes. Otherwise the number of argument names must equal arity_.num_args,
// or arity_.num_args + 1 for a varargs function.
Status Function::Validate() const {
  if (doc_->summary.empty()) {
    // No documentation given, nothing to check
    return Status::OK();
  }

  const int num_arg_names = static_cast<int>(doc_->arg_names.size());
  const bool names_fit_arity =
      num_arg_names == arity_.num_args ||
      (arity_.is_varargs && num_arg_names == arity_.num_args + 1);
  if (names_fit_arity) {
    return Status::OK();
  }
  return Status::Invalid("In function '", name_,
                         "': ", "number of argument names != function arity");
}
// Builds a kernel from `in_types`, `out_type`, `exec` and `init` and appends
// it to this function's kernels.
//
// Returns the CheckArity error when `in_types` does not fit the function's
// arity, and Invalid when the function is varargs but `in_types` does not hold
// exactly one type.
//
// NOTE(review): CheckArity requires at least arity_.num_args input types for
// a varargs function while the check below requires exactly one, so a varargs
// function with arity_.num_args > 1 cannot register a kernel through this
// overload - confirm this is intended.
Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
                                 ArrayKernelExec exec, KernelInit init) {
  RETURN_NOT_OK(CheckArity(in_types));

  if (arity_.is_varargs && in_types.size() != 1) {
    return Status::Invalid("VarArgs signatures must have exactly one input type");
  }
  auto sig =
      KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
  // `exec` and `init` are taken by value and not used afterwards: hand them
  // over by move instead of copying them a second time.
  kernels_.emplace_back(std::move(sig), std::move(exec), std::move(init));
  return Status::OK();
}
// Appends an already-built ScalarKernel to this function's kernels.
//
// Returns the CheckArity error when the kernel's input types do not fit the
// function's arity, and Invalid when the function is varargs but the kernel's
// signature is not.
Status ScalarFunction::AddKernel(ScalarKernel kernel) {
  const auto& signature = kernel.signature;
  RETURN_NOT_OK(CheckArity(signature->in_types()));

  const bool function_is_varargs = arity_.is_varargs;
  if (function_is_varargs && !signature->is_varargs()) {
    return Status::Invalid("Function accepts varargs but kernel signature does not");
  }

  kernels_.emplace_back(std::move(kernel));
  return Status::OK();
}
// Builds a kernel from `in_types`, `out_type`, `exec` and `init` and appends
// it to this function's kernels.
//
// Returns the CheckArity error when `in_types` does not fit the function's
// arity, and Invalid when the function is varargs but `in_types` does not hold
// exactly one type.
//
// NOTE(review): CheckArity requires at least arity_.num_args input types for
// a varargs function while the check below requires exactly one, so a varargs
// function with arity_.num_args > 1 cannot register a kernel through this
// overload - confirm this is intended.
Status VectorFunction::AddKernel(std::vector<InputType> in_types, OutputType out_type,
                                 ArrayKernelExec exec, KernelInit init) {
  RETURN_NOT_OK(CheckArity(in_types));

  if (arity_.is_varargs && in_types.size() != 1) {
    return Status::Invalid("VarArgs signatures must have exactly one input type");
  }
  auto sig =
      KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs);
  // `exec` and `init` are taken by value and not used afterwards: hand them
  // over by move instead of copying them a second time.
  kernels_.emplace_back(std::move(sig), std::move(exec), std::move(init));
  return Status::OK();
}
// Appends an already-built VectorKernel to this function's kernels.
//
// Returns the CheckArity error when the kernel's input types do not fit the
// function's arity, and Invalid when the function is varargs but the kernel's
// signature is not.
Status VectorFunction::AddKernel(VectorKernel kernel) {
  const auto& signature = kernel.signature;
  RETURN_NOT_OK(CheckArity(signature->in_types()));

  const bool function_is_varargs = arity_.is_varargs;
  if (function_is_varargs && !signature->is_varargs()) {
    return Status::Invalid("Function accepts varargs but kernel signature does not");
  }

  kernels_.emplace_back(std::move(kernel));
  return Status::OK();
}
// Appends an already-built ScalarAggregateKernel to this function's kernels.
//
// Returns the CheckArity error when the kernel's input types do not fit the
// function's arity, and Invalid when the function is varargs but the kernel's
// signature is not.
Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) {
  const auto& signature = kernel.signature;
  RETURN_NOT_OK(CheckArity(signature->in_types()));

  const bool function_is_varargs = arity_.is_varargs;
  if (function_is_varargs && !signature->is_varargs()) {
    return Status::Invalid("Function accepts varargs but kernel signature does not");
  }

  kernels_.emplace_back(std::move(kernel));
  return Status::OK();
}
// Appends an already-built HashAggregateKernel to this function's kernels.
//
// Returns the CheckArity error when the kernel's input types do not fit the
// function's arity, and Invalid when the function is varargs but the kernel's
// signature is not.
Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) {
  const auto& signature = kernel.signature;
  RETURN_NOT_OK(CheckArity(signature->in_types()));

  const bool function_is_varargs = arity_.is_varargs;
  if (function_is_varargs && !signature->is_varargs()) {
    return Status::Invalid("Function accepts varargs but kernel signature does not");
  }

  kernels_.emplace_back(std::move(kernel));
  return Status::OK();
}
// Checks the argument count against this function's arity, substitutes
// default_options() for a null `options`, then delegates to ExecuteImpl.
// `ctx` is passed through as given.
Result<Datum> MetaFunction::Execute(const std::vector<Datum>& args,
                                    const FunctionOptions* options,
                                    ExecContext* ctx) const {
  const int num_passed_args = static_cast<int>(args.size());
  RETURN_NOT_OK(CheckArityImpl(this, num_passed_args, "attempted to Execute with"));

  const FunctionOptions* effective_options =
      (options == nullptr) ? default_options() : options;
  return ExecuteImpl(args, effective_options, ctx);
}
} // namespace compute
} // namespace arrow