blob: fccc92a228308ee9d6c38f31c30b68236f41cb51 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_params.cc
*/
#ifdef TVM_LLVM_VERSION
#include "codegen_params.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/Support/Casting.h>
#include <algorithm>
#include <type_traits>
#include <vector>
namespace tvm {
namespace codegen {
/*!
 * \brief Trait that converts a C++ scalar of type T into an LLVM constant.
 *
 * The primary template only declares getElement; the definitions are supplied
 * by partial specializations selected through the second (SFINAE) parameter.
 *
 * \tparam T The C++ scalar type of the value to convert.
 * \tparam Enable SFINAE hook used by the partial specializations.
 */
template <typename T, typename Enable = void>
struct LLVMConstantGetter {
  /*!
   * \brief Build an LLVM constant of type \p ty holding \p value.
   * \param ty The LLVM type of the constant to create.
   * \param value The scalar value to encode.
   * \return The LLVM constant.
   */
  static llvm::Constant* getElement(llvm::Type* ty, T value);
};
/*!
 * \brief Specialization for signed integral types.
 *
 * Delegates to llvm::ConstantInt::getSigned.
 */
template <typename T>
struct LLVMConstantGetter<
    T, typename std::enable_if<std::is_signed<T>::value && std::is_integral<T>::value>::type> {
  /*!
   * \brief Build a signed integer constant.
   * \param ty The LLVM type of the constant.
   * \param value The signed value to encode.
   * \return The LLVM constant produced by llvm::ConstantInt::getSigned.
   */
  static llvm::Constant* getElement(llvm::Type* ty, T value) {
    llvm::Constant* signed_constant = llvm::ConstantInt::getSigned(ty, value);
    return signed_constant;
  }
};
/*!
 * \brief Specialization for unsigned integral types.
 *
 * Delegates to llvm::ConstantInt::get.
 */
template <typename T>
struct LLVMConstantGetter<
    T, typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value>::type> {
  /*!
   * \brief Build an unsigned integer constant.
   * \param ty The LLVM type of the constant.
   * \param value The unsigned value to encode.
   * \return The LLVM constant produced by llvm::ConstantInt::get.
   */
  static llvm::Constant* getElement(llvm::Type* ty, T value) {
    llvm::Constant* unsigned_constant = llvm::ConstantInt::get(ty, value);
    return unsigned_constant;
  }
};
/*!
 * \brief Specialization for floating-point types.
 *
 * Delegates to llvm::ConstantFP::get.
 */
template <typename T>
struct LLVMConstantGetter<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  /*!
   * \brief Build a floating-point constant.
   * \param ty The LLVM type of the constant.
   * \param value The floating-point value to encode.
   * \return The LLVM constant produced by llvm::ConstantFP::get.
   */
  static llvm::Constant* getElement(llvm::Type* ty, T value) {
    llvm::Constant* fp_constant = llvm::ConstantFP::get(ty, value);
    return fp_constant;
  }
};
/*!
 * \brief Convert a raw buffer of scalars into a vector of LLVM constants.
 *
 * The buffer is reinterpreted as an array of \p num_elements values of type T,
 * and each value is converted with LLVMConstantGetter<T>::getElement.
 *
 * The defaulted template argument uses std::enable_if_t so that the POD
 * constraint is actually evaluated; a bare std::enable_if<...> names a valid
 * type regardless of the condition and would therefore reject nothing.
 *
 * \tparam T The C++ element type stored in \p tensor_data.
 * \param element_type The LLVM type given to every generated constant.
 * \param tensor_data Pointer to the first element; read as an array of T.
 * \param num_elements Number of elements to read from \p tensor_data.
 * \param elements Output vector; resized to \p num_elements and overwritten.
 */
template <typename T, typename = std::enable_if_t<std::is_pod<T>::value>>
void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements,
                     std::vector<llvm::Constant*>* elements) {
  // The source buffer is only read, so view it through a pointer-to-const.
  const T* first = static_cast<const T*>(tensor_data);
  const T* last = first + num_elements;

  elements->resize(num_elements, nullptr);
  std::transform(first, last, elements->begin(), [element_type](T value) {
    return LLVMConstantGetter<T>::getElement(element_type, value);
  });
}
/*!
 * \brief Convert a tensor into an LLVM constant array with one entry per element.
 *
 * The tensor must be contiguous, reside on the CPU device, and have a 1-lane
 * data type. Supported element types:
 *  - signed / unsigned integers of 8, 16, 32 or 64 bits,
 *  - float of 16, 32 or 64 bits (float16 is emitted as 16-bit integers),
 *  - bfloat of 16 bits (emitted as 16-bit integers).
 * Any other data type triggers a check failure.
 *
 * \param ctx The LLVM context in which types and constants are created.
 * \param arr The tensor whose contents are converted.
 * \return The LLVM constant array; its length is the product of the tensor's shape extents.
 */
llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::Tensor arr) {
  llvm::Type* element_type = nullptr;

  auto arr_type = arr.DataType();
  TVM_FFI_ICHECK(arr.IsContiguous()) << "CodegenParams: only support contiguous arrays";
  TVM_FFI_ICHECK_EQ(arr->device.device_type, kDLCPU)
      << "CodegenParams: only support arrays resident on the CPU device";
  TVM_FFI_ICHECK_EQ(arr_type.lanes(), 1)
      << "CodegenParams: only support generating 1-lane parameters; saw " << arr_type.lanes();

  // Accumulate in 64 bits: a plain int would be narrowed from the shape extents
  // and could overflow for large tensors.
  auto shape = arr.Shape();
  int64_t num_elements = 1;
  for (auto shape_elem : shape) {
    num_elements *= shape_elem;
  }

  std::vector<llvm::Constant*> elements;

  switch (arr_type.code()) {
    case runtime::DataType::TypeCode::kInt:
      TVM_FFI_ICHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
                     arr_type.bits() == 64)
          << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
          << arr_type.bits() << "-bit array";
      element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());

      switch (arr_type.bits()) {
        case 8:
          BuildLLVMVector<int8_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 16:
          BuildLLVMVector<int16_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 32:
          BuildLLVMVector<int32_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 64:
          BuildLLVMVector<int64_t>(element_type, arr->data, num_elements, &elements);
          break;
        default:
          // Guarded by the bit-width check above.
          TVM_FFI_ICHECK(false) << "should not get here";
          break;
      }
      break;

    case runtime::DataType::TypeCode::kUInt:
      TVM_FFI_ICHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
                     arr_type.bits() == 64)
          << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
          << arr_type.bits() << "-bit array";
      element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());

      switch (arr_type.bits()) {
        case 8:
          BuildLLVMVector<uint8_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 16:
          BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 32:
          BuildLLVMVector<uint32_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 64:
          BuildLLVMVector<uint64_t>(element_type, arr->data, num_elements, &elements);
          break;
        default:
          // Guarded by the bit-width check above.
          TVM_FFI_ICHECK(false) << "should not get here";
          break;
      }
      break;

    case runtime::DataType::TypeCode::kFloat:
      switch (arr_type.bits()) {
        case 16:
          // NOTE: float16 is treated as uint16_t.
          element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
          BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
          break;
        case 32:
          element_type = llvm::Type::getFloatTy(*ctx);
          BuildLLVMVector<float>(element_type, arr->data, num_elements, &elements);
          break;
        case 64:
          element_type = llvm::Type::getDoubleTy(*ctx);
          BuildLLVMVector<double>(element_type, arr->data, num_elements, &elements);
          break;
        default:
          TVM_FFI_ICHECK(false)
              << "CodegenParams: only support 16-, 32-, or 64-bit floating point; saw "
              << arr_type.bits() << "-bit array";
          break;
      }
      break;

    case runtime::DataType::TypeCode::kBFloat:
      TVM_FFI_ICHECK(arr_type.bits() == 16)
          << "CodegenParams: only support 16-bit bfloat; saw " << arr_type.bits() << "-bit array";
      // bfloat16 is emitted as 16-bit integers.
      element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits());
      BuildLLVMVector<uint16_t>(element_type, arr->data, num_elements, &elements);
      // Without this break the case would fall into `default` and always fail.
      break;

    default:
      TVM_FFI_ICHECK(false) << "Data type not supported";
      break;
  }

  // NOTE(review): llvm::ConstantArray::get returns llvm::Constant*; confirm that the
  // value it yields for these inputs is always a ConstantArray, otherwise this cast is
  // only safe when LLVM assertions are disabled.
  return llvm::cast<llvm::ConstantArray>(llvm::ConstantArray::get(
      llvm::ArrayType::get(element_type, num_elements), llvm::ArrayRef<llvm::Constant*>(elements)));
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION