blob: 9f19fa1f47f0f934e9eb14e33c7adf79f10178cf [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_arm.cc
* \brief ARM specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include "codegen_cpu.h"
namespace tvm {
namespace codegen {
// ARM-specific code generator.
// It also serves as an example of how to override the behavior of the
// LLVM code generator for a particular target.
class CodeGenARM final : public CodeGenCPU {
 public:
  // Sets the native vector width and then runs the CPU initialization.
  void InitTarget(llvm::TargetMachine* tm) final {
    // Native vector width: 16 bytes of 8 bits each, i.e. 128 bits.
    const int bytes_per_vector = 16;
    const int bits_per_byte = 8;
    native_vector_bits_ = bytes_per_vector * bits_per_byte;
    CodeGenCPU::InitTarget(tm);
  }

  // Intrinsic hook; see the out-of-class definition for the ARM-specific part.
  llvm::Value* CreateIntrinsic(const Call* op) override;

 private:
  // Builds the expression used to lower a population-count intrinsic call.
  Expr ARMPopcount(const Call* op);
};
// Intercepts `llvm_intrin` calls whose intrinsic id is llvm.ctpop and lowers
// them through ARMPopcount(); every other call is handed to CodeGenCPU as is.
//
// \param op The intrinsic call being generated.
// \return The LLVM value produced by CodeGenCPU::CreateIntrinsic for either
//         the rewritten popcount expression or the untouched call.
llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
  // ARMPopcount() reads args[2], so only intercept calls that carry at least
  // three arguments. Anything else is left to the base class.
  if (op->is_intrinsic("llvm_intrin") && op->args.size() >= 3) {
    // args[0] holds the LLVM intrinsic id as an unsigned immediate. If it is
    // some other kind of expression we do not recognize the call and simply
    // delegate instead of dereferencing a null pointer.
    const UIntImm* id_imm = op->args[0].as<UIntImm>();
    if (id_imm != nullptr) {
      const llvm::Intrinsic::ID id =
          static_cast<llvm::Intrinsic::ID>(id_imm->value);
      if (id == ::llvm::Intrinsic::ctpop) {
        // Keep the rewritten expression alive in a local while the base class
        // consumes the raw node pointer obtained from it.
        Expr lowered = ARMPopcount(op);
        // The qualified call goes straight to the CPU implementation, so the
        // rewritten top-level call is not intercepted again here.
        return CodeGenCPU::CreateIntrinsic(lowered.as<Call>());
      }
    }
  }
  return CodeGenCPU::CreateIntrinsic(op);
}
// Rewrites a ctpop `llvm_intrin` call into an ARM-friendly expression.
//
// For 64-bit or 128-bit vectors with 16/32/64-bit elements the operand is
// reinterpreted as a vector of 8-bit lanes, popcounted per lane, and then
// widened back to the original element width with pairwise-add-long
// (arm_neon_vpaddlu) steps. Every other input is re-emitted as a plain ctpop
// `llvm_intrin` call so the default LLVM lowering applies.
//
// \param call The original `llvm_intrin` call; args[2] is the popcount operand.
// \return A new `llvm_intrin` Call expression (never the input call itself).
Expr CodeGenARM::ARMPopcount(const Call *call) {
  using namespace ir;
  const Expr& e = call->args[2];
  const ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
  const ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;

  // Builds `llvm_intrin(id, 1, operand)` with the given result type.
  // NOTE(review): the literal 1 is kept from the original code; presumably it
  // is the operand count used to resolve the intrinsic signature - confirm
  // against the base LLVM code generator.
  auto make_unary_intrin = [](Type result_type, ::llvm::Intrinsic::ID id,
                              const Expr& operand) -> Expr {
    Array<Expr> intrin_args;
    intrin_args.push_back(ir::UIntImm::make(UInt(32), id));
    intrin_args.push_back(ir::UIntImm::make(UInt(32), 1));
    intrin_args.push_back(operand);
    return ir::Call::make(result_type, "llvm_intrin", intrin_args, Call::PureIntrinsic);
  };

  // Fall back to the default LLVM lowering rule unless the input is a full
  // (128-bit) or half (64-bit) vector whose element width the widening chain
  // below can actually produce. Only 16, 32 and 64-bit elements qualify:
  // 8-bit elements need no widening, and any other width would otherwise fall
  // through to the last step and be given a mismatched result type.
  const int elem_bits = call->type.bits();
  const int total_size = elem_bits * call->type.lanes();
  const bool widening_supported =
      elem_bits == 16 || elem_bits == 32 || elem_bits == 64;
  if (!call->type.is_vector() || !widening_supported ||
      (total_size != 128 && total_size != 64)) {
    return make_unary_intrin(call->type, ctpop_id, e);
  }

  // Popcount lowering rule:
  // Reinterpret the input vector as a vector of 8-bit values and perform
  // popcount on it, then pairwise-add adjacent elements while doubling their
  // width with vpaddlu until the original element width is reached.
  // Divisions are always exact (number of bits = 64 or 128).
  // The type code is copied from the operand, so despite the variable names
  // these are only unsigned when the operand itself is.
  Type uint8_type = Type(e.type().code(), 8, e.type().bits() * e.type().lanes() / 8);
  Type uint16_type = Type(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
  Type uint32_type = Type(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);

  // Interpret the input as a vector of 8-bit values.
  Expr input8 = reinterpret(uint8_type, e);
  // Kept from the original: the reinterpreted operand must be a Call node.
  CHECK(input8.as<Call>() != nullptr);

  // Popcount 8bit->8bit
  Expr vcnt8 = make_unary_intrin(uint8_type, ctpop_id, input8);

  // Accumulation 8->16bit
  Expr vcnt16 = make_unary_intrin(uint16_type, vpaddlu_id, vcnt8);
  if (elem_bits == 16) {
    return vcnt16;
  }

  // Accumulation 16->32bit
  Expr vcnt32 = make_unary_intrin(uint32_type, vpaddlu_id, vcnt16);
  if (elem_bits == 32) {
    return vcnt32;
  }

  // Accumulation 32->64bit; the guard above leaves 64 as the only other width.
  return make_unary_intrin(call->type, vpaddlu_id, vcnt32);
}
// Registers the global function "tvm.codegen.llvm.target_arm", which creates
// a new CodeGenARM and returns it as an opaque pointer.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
.set_body([](const TVMArgs& args, TVMRetValue* ret) {
    // Convert to the CodeGenLLVM base pointer first, then erase the type, so
    // the stored address is that of the CodeGenLLVM subobject.
    // NOTE(review): the object is heap-allocated and not freed here;
    // presumably the caller takes ownership - confirm at the call site.
    CodeGenLLVM* arm_codegen = new CodeGenARM();
    *ret = static_cast<void*>(arm_codegen);
  });
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION