blob: d062963553045c40e9fa4197ddc3a559c8fe7ceb [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "gandiva/expression_registry.h"
#include "boost/iterator/transform_iterator.hpp"
#include "gandiva/function_registry.h"
#include "gandiva/llvm_types.h"
namespace gandiva {
// Each ExpressionRegistry owns its own freshly constructed FunctionRegistry,
// which backs the function-signature iteration below.
ExpressionRegistry::ExpressionRegistry()
    : function_registry_(new FunctionRegistry()) {}

// Out-of-line so the owned FunctionRegistry is destroyed in this translation
// unit, where its definition is visible.
ExpressionRegistry::~ExpressionRegistry() = default;
// Returns an iterator positioned at the first function exposed by the owned
// FunctionRegistry.
const ExpressionRegistry::FunctionSignatureIterator
ExpressionRegistry::function_signature_begin() {
  auto first = function_registry_->begin();
  return FunctionSignatureIterator{first};
}
// Returns the past-the-end iterator for the owned FunctionRegistry; compare
// against it with operator!= to terminate a walk.
const ExpressionRegistry::FunctionSignatureIterator
ExpressionRegistry::function_signature_end() const {
  auto last = function_registry_->end();
  return FunctionSignatureIterator{last};
}
// Two signature iterators differ exactly when their underlying registry
// iterators differ.
bool ExpressionRegistry::FunctionSignatureIterator::operator!=(
    const FunctionSignatureIterator& other) {
  return it_ != other.it_;
}
// Yields, by value, the signature of the registry entry the iterator
// currently references.
FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() {
  return it_->signature();
}
// Postfix increment: advances to the next registry entry. Note that it returns
// the underlying registry iterator as it was before advancing, not a
// FunctionSignatureIterator.
ExpressionRegistry::iterator ExpressionRegistry::FunctionSignatureIterator::operator++(
    int /*postfix tag*/) {
  iterator previous = it_;
  ++it_;
  return previous;
}
// Computed once, during dynamic initialization of this translation unit.
// The initializer is looked up in class scope, so no qualification is needed.
DataTypeVector ExpressionRegistry::supported_types_ = InitSupportedTypes();
// Builds the list of concrete arrow DataTypes corresponding to every arrow
// type id that LLVMTypes reports as supported. Each id is expanded by
// AddArrowTypesToVector (some ids yield several DataType instances).
DataTypeVector ExpressionRegistry::InitSupportedTypes() {
  // LLVMTypes needs an LLVM context to be constructed; this one is local and
  // only lives for the duration of this query.
  llvm::LLVMContext context;
  LLVMTypes types(context);

  // Keep a named, non-const copy: AddArrowTypesToVector takes its id by
  // non-const reference.
  auto ids = types.GetSupportedArrowTypes();

  DataTypeVector result;
  for (auto& id : ids) {
    AddArrowTypesToVector(id, result);
  }
  return result;
}
// Appends the arrow DataType instance(s) corresponding to the type id `type`
// onto `vector`.
//
// Most ids map to exactly one instance. Parameterized ids expand into a fixed
// set of concrete instances, appended in this order:
//   TIMESTAMP -> SECOND, MILLI, NANO, MICRO
//   TIME32    -> SECOND, MILLI
//   TIME64    -> MICRO, NANO
//   DECIMAL   -> decimal(38, 0) only
// `type` is taken by non-const reference but is only read here.
void ExpressionRegistry::AddArrowTypesToVector(arrow::Type::type& type,
                                               DataTypeVector& vector) {
  switch (type) {
    case arrow::Type::type::NA:
      vector.push_back(arrow::null());
      return;
    case arrow::Type::type::BOOL:
      vector.push_back(arrow::boolean());
      return;
    case arrow::Type::type::UINT8:
      vector.push_back(arrow::uint8());
      return;
    case arrow::Type::type::INT8:
      vector.push_back(arrow::int8());
      return;
    case arrow::Type::type::UINT16:
      vector.push_back(arrow::uint16());
      return;
    case arrow::Type::type::INT16:
      vector.push_back(arrow::int16());
      return;
    case arrow::Type::type::UINT32:
      vector.push_back(arrow::uint32());
      return;
    case arrow::Type::type::INT32:
      vector.push_back(arrow::int32());
      return;
    case arrow::Type::type::UINT64:
      vector.push_back(arrow::uint64());
      return;
    case arrow::Type::type::INT64:
      vector.push_back(arrow::int64());
      return;
    case arrow::Type::type::HALF_FLOAT:
      vector.push_back(arrow::float16());
      return;
    case arrow::Type::type::FLOAT:
      vector.push_back(arrow::float32());
      return;
    case arrow::Type::type::DOUBLE:
      vector.push_back(arrow::float64());
      return;
    case arrow::Type::type::STRING:
      vector.push_back(arrow::utf8());
      return;
    case arrow::Type::type::BINARY:
      vector.push_back(arrow::binary());
      return;
    case arrow::Type::type::DATE32:
      vector.push_back(arrow::date32());
      return;
    case arrow::Type::type::DATE64:
      vector.push_back(arrow::date64());
      return;
    case arrow::Type::type::TIMESTAMP:
      for (auto unit : {arrow::TimeUnit::SECOND, arrow::TimeUnit::MILLI,
                        arrow::TimeUnit::NANO, arrow::TimeUnit::MICRO}) {
        vector.push_back(arrow::timestamp(unit));
      }
      return;
    case arrow::Type::type::TIME32:
      for (auto unit : {arrow::TimeUnit::SECOND, arrow::TimeUnit::MILLI}) {
        vector.push_back(arrow::time32(unit));
      }
      return;
    case arrow::Type::type::TIME64:
      for (auto unit : {arrow::TimeUnit::MICRO, arrow::TimeUnit::NANO}) {
        vector.push_back(arrow::time64(unit));
      }
      return;
    case arrow::Type::type::DECIMAL:
      vector.push_back(arrow::decimal(38, 0));
      return;
    default:
      // Type id with no mapping above. Nothing is appended; the DCHECK fires
      // in debug builds. A test is expected to break the build when a newly
      // supported id is added without a corresponding case here.
      DCHECK(false);
  }
}
std::vector<std::shared_ptr<FunctionSignature>> GetRegisteredFunctionSignatures() {
ExpressionRegistry registry;
std::vector<std::shared_ptr<FunctionSignature>> signatures;
for (auto iter = registry.function_signature_begin();
iter != registry.function_signature_end(); iter++) {
signatures.push_back(std::make_shared<FunctionSignature>(
(*iter).base_name(), (*iter).param_types(), (*iter).ret_type()));
}
return signatures;
}
} // namespace gandiva