| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file codegen_arm.cc |
| * \brief ARM specific code generator |
| */ |
| #ifdef TVM_LLVM_VERSION |
| |
| #include <llvm/IR/Intrinsics.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #if TVM_LLVM_VERSION >= 100 |
| #include <llvm/IR/IntrinsicsARM.h> |
| #endif |
| #include <llvm/Target/TargetMachine.h> |
| |
| #include "codegen_cpu.h" |
| |
| namespace tvm { |
| namespace codegen { |
| |
| // ARM specific code generator, this is used as an example on |
| // how to override behavior llvm code generator for specific target |
| class CodeGenARM final : public CodeGenCPU { |
| public: |
| CodeGenARM() = default; |
| virtual ~CodeGenARM() = default; |
| |
| void InitTarget() final { |
| // set native vector bits. |
| native_vector_bits_ = 16 * 8; |
| CodeGenCPU::InitTarget(); |
| } |
| llvm::Value* CreateIntrinsic(const CallNode* op) override; |
| |
| private: |
| PrimExpr ARMPopcount(const CallNode* op); |
| }; |
| |
| llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { |
| if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { |
| llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value); |
| if (id == llvm::Intrinsic::ctpop) { |
| PrimExpr e = ARMPopcount(op); |
| return CodeGenCPU::CreateIntrinsic(e.as<CallNode>()); |
| } |
| } |
| return CodeGenCPU::CreateIntrinsic(op); |
| } |
| |
| PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { |
| using namespace tir; |
| const PrimExpr& e = call->args[1]; |
| llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop; |
| llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu; |
| |
| // Fallback to default llvm lowering rule if input type not a full vector or half vector length |
| int total_size = call->dtype.bits() * call->dtype.lanes(); |
| if (!call->dtype.is_fixed_length_vector() || call->dtype.bits() == 8 || |
| (total_size != 128 && total_size != 64)) { |
| ffi::Array<PrimExpr> vcnt_args; |
| vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); |
| vcnt_args.push_back(e); |
| return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args); |
| } |
| |
| // Popcount lowering rule: |
| // Reinterpret input vector as a vector of 8bit values and preform popcount |
| // Pairwise add between adjacent elements and double width with vpaddlu |
| // to return back to original input type |
| |
| // Dvisions are always divisible (number of bits = 64 or 128) |
| DataType uint8_type = DataType(e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); |
| DataType uint16_type = |
| DataType(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); |
| DataType uint32_type = |
| DataType(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); |
| |
| // Interpret input as vector of 8bit values |
| PrimExpr input8 = reinterpret(uint8_type, e); |
| // Popcount 8bit->8bit |
| const CallNode* c0 = input8.as<CallNode>(); |
| TVM_FFI_ICHECK(c0 != nullptr); |
| ffi::Array<PrimExpr> vcnt8_args; |
| vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); |
| vcnt8_args.push_back(input8); |
| PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); |
| |
| // Accumulation 8->16bit |
| ffi::Array<PrimExpr> vcnt16_args; |
| vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); |
| vcnt16_args.push_back(vcnt8); |
| PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); |
| if (call->dtype.bits() == 16) { |
| return vcnt16; |
| } |
| |
| // Accumulation 16->32bit |
| ffi::Array<PrimExpr> vcnt32_args; |
| vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); |
| vcnt32_args.push_back(vcnt16); |
| PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); |
| if (call->dtype.bits() == 32) { |
| return vcnt32; |
| } |
| |
| // Accumulation 32->64bit |
| ffi::Array<PrimExpr> vcnt64_args; |
| vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); |
| vcnt64_args.push_back(vcnt32); |
| return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("tvm.codegen.llvm.target_arm", |
| [](const ffi::PackedArgs& targs, ffi::Any* rv) { |
| *rv = static_cast<void*>(new CodeGenARM()); |
| }); |
| } |
| |
| } // namespace codegen |
| } // namespace tvm |
| |
| #endif // TVM_LLVM_VERSION |