| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file codegen_llvm.h |
| * \brief Common base class for generating into LLVM IR |
| */ |
| #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ |
| #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ |
| |
| #ifdef TVM_LLVM_VERSION |
| |
| #include <llvm/ADT/ArrayRef.h> |
| #include <llvm/ADT/StringRef.h> |
| #include <llvm/IR/BasicBlock.h> |
| #include <llvm/IR/ConstantFolder.h> |
| #include <llvm/IR/Constants.h> |
| #include <llvm/IR/DerivedTypes.h> |
| #if TVM_LLVM_VERSION >= 150 |
| #include <llvm/IR/FMF.h> |
| #else |
| #include <llvm/IR/Operator.h> |
| #endif |
| #include <llvm/IR/GlobalValue.h> |
| #include <llvm/IR/IRBuilder.h> |
| #include <llvm/IR/Instructions.h> |
| #include <llvm/IR/Intrinsics.h> |
| #include <llvm/Support/Casting.h> |
| #if TVM_LLVM_VERSION >= 140 |
| #include <llvm/MC/TargetRegistry.h> |
| #else |
| #include <llvm/Support/TargetRegistry.h> |
| #endif |
| |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/ir/module.h> |
| #include <tvm/target/codegen.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/function.h> |
| #include <tvm/tir/op.h> |
| #include <tvm/tir/op_attr_types.h> |
| #include <tvm/tir/stmt.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
| |
| #include "../../runtime/thread_storage_scope.h" |
| #include "../../tir/transforms/ir_utils.h" |
| #include "codegen_params.h" |
| |
namespace llvm {
// LLVM classes that this header only refers to through pointers.
class Argument;
class CallInst;
class DICompileUnit;
class DIFile;
class Function;
class GlobalVariable;
class Instruction;
class MDNode;
class PassManagerBuilder;

// LLVM classes held through std::unique_ptr members below; their complete
// definitions are required wherever the owning destructor is defined, which
// is why those destructors are declared here but defined out of line.
class DataLayout;
class DIBuilder;
class MDBuilder;
class Module;
}  // namespace llvm
| |
| namespace tvm { |
| namespace codegen { |
| |
| class LLVMTarget; |
| |
| using namespace tir; |
| |
| /*! |
| * \brief A base class to generate a LLVM. |
| */ |
| class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>, |
| public StmtFunctor<void(const Stmt&)> { |
| public: |
| CodeGenLLVM(); // Do not make it default here. |
| virtual ~CodeGenLLVM(); // Do not make it default here. |
| |
| /*! |
| * \brief Create new code generator based on target machine. |
| * \param tm The target machine |
| * \return The created llvm generator. |
| */ |
| static std::unique_ptr<CodeGenLLVM> Create(LLVMTarget* llvm_target); |
| /*! |
| * \brief Initialize the code generator with given context |
| * \param module_name The name of the module. |
| * \param tm Target machine model |
| * \param ctx The context. |
| * \param system_lib Whether to insert system library registration. |
| * \param dynamic_lookup Whether dynamically lookup runtime function |
| * or use the runtime function table passed by caller. |
| * \param target_c_runtime If true, generate a module to be executed by the C runtime. In practice |
| * this option influences whether global ctors are used. |
| */ |
| virtual void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, |
| bool dynamic_lookup, bool target_c_runtime); |
| |
| /*! |
| * \brief Turn on fast math flags for floating point operations. |
| * \param fmf FastMathFlags to use for code generation. |
| */ |
| void SetFastMathFlags(llvm::FastMathFlags fmf); |
| |
| /*! |
| * \brief Compile and add function f to the current module. |
| * \param f The function to be added. |
| */ |
| virtual void AddFunction(const PrimFunc& f); |
| /*! |
| * \brief Add main function as the entry name |
| * \param entry_func_name The name of entry function to be added. |
| */ |
| virtual void AddMainFunction(const std::string& entry_func_name); |
| /*! |
| * \brief Finish current pass of codegen, get the module. |
| * \return the created module. |
| */ |
| virtual std::unique_ptr<llvm::Module> Finish(); |
| /*! |
| * \brief Add functions from the (unordered) range to the current module in a deterministic order. |
| * The range consists of objects convertible to PrimFunc. |
| * \param begin The beginning of the range. |
| * \param end The end of the range. |
| * \param pfunc Converter function from the range element type to PrimFunc. |
| */ |
| template <typename IterType, typename ConvType> |
| void AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc); |
| /*! |
| * \brief Add functions from the (unordered) range of elements of type PrimFunc to the current |
| * module in a deterministic order. |
| * \param begin The beginning of the range. |
| * \param end The end of the range. |
| */ |
| template <typename IterType> |
| void AddFunctionsOrdered(IterType begin, IterType end) { |
| this->AddFunctionsOrdered(begin, end, [](auto f) { return f; }); |
| } |
| /*! |
| * \brief Add mod to be linked with the generated module |
| * \param mod The module to be linked. |
| */ |
| void AddLinkModule(std::unique_ptr<llvm::Module>&& mod); |
| /*! |
| * \brief Create Value for expression e |
| * \param e The expression to be created value for. |
| * \return created value. |
| */ |
| llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } |
| // Short hande code to get a constant int 32 |
| llvm::Constant* ConstInt32(int64_t value) const { |
| return llvm::ConstantInt::getSigned(t_int32_, value); |
| } |
| // override codegen |
| llvm::Value* VisitExpr_(const VarNode* op) override; |
| llvm::Value* VisitExpr_(const CastNode* op) override; |
| llvm::Value* VisitExpr_(const IntImmNode* op) override; |
| llvm::Value* VisitExpr_(const FloatImmNode* op) override; |
| llvm::Value* VisitExpr_(const StringImmNode* op) override; |
| llvm::Value* VisitExpr_(const AddNode* op) override; |
| llvm::Value* VisitExpr_(const SubNode* op) override; |
| llvm::Value* VisitExpr_(const MulNode* op) override; |
| llvm::Value* VisitExpr_(const DivNode* op) override; |
| llvm::Value* VisitExpr_(const ModNode* op) override; |
| llvm::Value* VisitExpr_(const MinNode* op) override; |
| llvm::Value* VisitExpr_(const MaxNode* op) override; |
| llvm::Value* VisitExpr_(const LTNode* op) override; |
| llvm::Value* VisitExpr_(const LENode* op) override; |
| llvm::Value* VisitExpr_(const GTNode* op) override; |
| llvm::Value* VisitExpr_(const GENode* op) override; |
| llvm::Value* VisitExpr_(const EQNode* op) override; |
| llvm::Value* VisitExpr_(const NENode* op) override; |
| llvm::Value* VisitExpr_(const AndNode* op) override; |
| llvm::Value* VisitExpr_(const OrNode* op) override; |
| llvm::Value* VisitExpr_(const NotNode* op) override; |
| llvm::Value* VisitExpr_(const SelectNode* op) override; |
| llvm::Value* VisitExpr_(const LetNode* op) override; |
| llvm::Value* VisitExpr_(const LoadNode* op) override; |
| llvm::Value* VisitExpr_(const BufferLoadNode* op) override; |
| llvm::Value* VisitExpr_(const CallNode* op) override; |
| llvm::Value* VisitExpr_(const RampNode* op) override; |
| llvm::Value* VisitExpr_(const ShuffleNode* op) override; |
| llvm::Value* VisitExpr_(const BroadcastNode* op) override; |
| // stmt |
| void VisitStmt_(const StoreNode* op) override; |
| void VisitStmt_(const BufferStoreNode* op) override; |
| void VisitStmt_(const ForNode* op) override; |
| void VisitStmt_(const WhileNode* op) override; |
| void VisitStmt_(const IfThenElseNode* op) override; |
| void VisitStmt_(const AllocateNode* op) override; |
| void VisitStmt_(const AllocateConstNode* op) override; |
| void VisitStmt_(const AttrStmtNode* op) override; |
| void VisitStmt_(const AssertStmtNode* op) override; |
| void VisitStmt_(const LetStmtNode* op) override; |
| void VisitStmt_(const SeqStmtNode* op) override; |
| void VisitStmt_(const EvaluateNode* op) override; |
| |
| // Get constant string |
| llvm::Constant* GetConstString(const std::string& str); |
| |
| llvm::Constant* GetGlobalConstant( |
| llvm::Constant* const_data, const std::string& name = "", |
| llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); |
| |
| protected: |
| /*! |
| * \brief Address and type pair to assist in handling opaque pointers. |
| */ |
| struct TypedPointer { |
| TypedPointer() = default; |
| TypedPointer(llvm::Type* t, llvm::Value* a) : type(t), addr(a) {} |
| llvm::Type* type = nullptr; /*!< Type of the value pointed to. */ |
| llvm::Value* addr = nullptr; /*!< Address of the value. */ |
| }; |
| /*! \brief The storage information */ |
| struct StorageInfo { |
| /*! \brief The alignment of allocation */ |
| int alignment{0}; |
| }; |
| /*! |
| * \brief Convert tvm::runtime::String into llvm::StringRef |
| */ |
| static llvm::StringRef MakeStringRef(const String& string) { |
| return llvm::StringRef(string.c_str(), string.size()); |
| } |
| /*! |
| * \brief Execute falloca at the beginning of the |
| * currrent function and obtain its return value. |
| * |
| * This is a helper function to make sure that |
| * alloca always happen in the beginning of the function. |
| * |
| * \param falloca The allocation function to be executed. |
| * \tparam F The function to be executed. |
| * \return The result. |
| */ |
| template <typename F> |
| llvm::AllocaInst* WithFunctionEntry(F falloca) { |
| llvm::BasicBlock* current = builder_->GetInsertBlock(); |
| llvm::BasicBlock* entry = &(function_->getEntryBlock()); |
| builder_->SetInsertPoint(entry, entry->begin()); |
| llvm::AllocaInst* res = falloca(); |
| builder_->SetInsertPoint(current); |
| return res; |
| } |
| // create intrinstic given call |
| virtual llvm::Value* CreateIntrinsic(const CallNode* op); |
| // create extern function call |
| // skip first arg mode used for call extern intrinsic. |
| virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, |
| const Array<PrimExpr>& args, bool skip_first_arg); |
| |
| /*! \brief Insert a printf() call to the generated LLVM |
| * |
| * This is intended solely for debugging purposes. After calling |
| * printf(), immediately calls fflush() to flush the stdout buffer |
| * in case of segfault. |
| */ |
| virtual void CreatePrintf(const std::string& format, llvm::ArrayRef<llvm::Value*> format_args); |
| |
| /*! \brief Lookup return address, for debugging purposes |
| * |
| * This is intended solely for debugging purposes. Calls the |
| * `llvm::Intrinsic::returnaddress`, returning the return address of |
| * the current function call. |
| * |
| * \param level Look up the return address of a frame `level` steps |
| * above the current stack frame. |
| */ |
| llvm::Value* CreateLookupReturnAddress(unsigned int level = 0); |
| |
| // Get the corresponding thread index |
| virtual llvm::Value* GetThreadIndex(const IterVar& iv); |
| // Get the corresponding thread index |
| virtual llvm::Value* CreateStorageSync(const CallNode* op); |
| #if TVM_LLVM_VERSION < 160 |
| // This function only works with the legacy pass manager. |
| // apply optimization on the module. |
| virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); |
| #endif |
| // Scalarize by iterating elements of e. |
| // f is a callback that takes index and v. |
| void Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f); |
| |
| /* \brief Helper function for handling buffer access |
| * |
| * \param buffer The buffer being accessed |
| * |
| * \param indices The indices at which the buffer is being accessed. |
| * |
| * \param value_dtype The datatype to be read from (BufferLoad) or |
| * written to (BufferStore) the buffer. |
| * |
| * \param make_instruction A callback function that generates that |
| * actual call. |
| * |
| * - buffer_ptr: A typed pointer to the element being accessed |
| * |
| * - subelement_i: The index of a vectorized type to be |
| * stored/loaded. If -1, indicates that the entire type, |
| * vector or scalar, should be written. |
| * |
| * - alignment: The alignment to be used for the read/write. |
| * |
| * - is_volatile: Whether the read/write should be volatile. |
| * |
| * - Should return the generated expression. |
| */ |
| void BufferAccessHelper( |
| Buffer buffer, Array<PrimExpr> indices, DataType value_dtype, |
| std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment, |
| bool is_volatile)> |
| make_instruction); |
| // Initialize target |
| virtual void InitTarget(); |
| // Add module startup function if needed. |
| virtual void AddStartupFunction() {} |
| // apply optimization on the module. |
| virtual void Optimize(); |
| // Get the maximim storage align bits of buffer pointer given storage scope. |
| virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const; |
| // Get correct address space depending on the backend |
| virtual unsigned GetGlobalAddressSpace() const; |
| void AddFunctionInternal(const PrimFunc& f, bool ret_void); |
| // Create extern call |
| llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, |
| const std::vector<llvm::Value*>& value); |
| /*! |
| * \brief Get the LLVM Type for a given runtime type. |
| * \param dtype The runtime dtype. |
| * |
| * \note Only use this function for dealing with PrimTypes. |
| * For Call and Var that could have more refined types, |
| * use GetLLVMType instead. |
| * |
| * \return LLVM type of dtype |
| */ |
| llvm::Type* DTypeToLLVMType(const DataType& dtype) const; |
| /*! |
| * \brief Get the LLVM Type for a given type. |
| * \param dtype The runtime dtype. |
| * \param type The corresponding TVM Type. |
| */ |
| llvm::Type* GetLLVMType(const Type& type) const; |
| /*! |
| * \brief Get the LLVM Type for a given type. |
| * \param dtype The runtime dtype. |
| * \param type The corresponding TVM Type. |
| */ |
| llvm::Type* GetLLVMType(const PrimExpr& expr) const; |
| /*! |
| * \brief Get the declaration of the LLVM intrinsic based on the intrinsic |
| * id, and the type of the actual call. |
| * |
| * \param id The intrinsic id. |
| * \param ret_type The call return type. |
| * \param arg_types The types of the call arguments. |
| * |
| * \return Return the llvm::Function pointer, or nullptr if the declaration |
| * could not be generated (e.g. if the argument/return types do not |
| * match). |
| */ |
| llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, |
| llvm::ArrayRef<llvm::Type*> arg_types); |
| /*! |
| * \brief Set target-related attributes on the LLVM function \p func. This |
| * includes "target-cpu" and "target-features" if present. |
| * |
| * \param func The function to set attributes on. |
| */ |
| void SetTargetAttributes(llvm::Function* func); |
| /*! |
| * \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2 |
| * into the current llvm::Module. |
| * |
| * \param use_float16_abi Whether to use floating-point or integer ABI. |
| */ |
| void EmitFloat16ConversionBuiltins(bool use_float16_abi); |
| |
| /*! |
| * \brief Get the number of elements in the given vector value. |
| * \param vec The value, must be of a vector type. |
| */ |
| inline int GetVectorNumElements(llvm::Value* vec); |
| // initialize the function state. |
| void InitFuncState(); |
| // Get alignment given index. |
| void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, |
| int* p_native_bits); |
| // Returns whether the LLVM type has padding for alignment |
| bool HasAlignmentPadding(DataType dtype); |
| // do a scalarize call with f |
| llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, |
| const std::vector<llvm::Value*>& args); |
| // handle module import |
| void HandleImport(const std::string& code); |
| // cast operatpr |
| llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value); |
| // comparison op |
| llvm::Value* GetVarValue(const VarNode* v) const; |
| llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateGE(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); |
| llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); |
| virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, |
| llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype); |
| // Vector concatenation. |
| llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); |
| llvm::Value* CreateVecFlip(llvm::Value* vec); |
| llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs); |
| llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); |
| // Create serial for |
| void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, |
| const Var& loop_var, const Stmt& body); |
| // add alias information. |
| void AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_var, PrimExpr index, |
| DataType access_dtype); |
| |
| llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size, |
| unsigned int shared_address_space, int alignment, |
| llvm::GlobalValue::LinkageTypes linkage); |
| |
| /*! |
| * \brief Get the `i`th argument to the given function, respecting LLVM API changes. |
| * |
| * NOTE: in LLVM < 10.0, the underlying API returns a const llvm::Argument*. To provide a uniform |
| * API, const is removed here. Proper usage of LLVM APIs depends on having a non-const Argument*, |
| * so we take this appraoch here rather than adding const. |
| * |
| * \param function The function containing the arguments. |
| * \param i The index of the argument to retrieve. |
| * \return The retrieved argument. |
| */ |
| llvm::Argument* GetArg(const llvm::Function* function, int i) const { |
| #if TVM_LLVM_VERSION >= 100 |
| return function->getArg(i); |
| #elif TVM_LLVM_VERSION >= 50 |
| return const_cast<llvm::Argument*>(&function->arg_begin()[i]); |
| #else |
| return const_cast<llvm::Argument*>(&*std::next(function->arg_begin(), i)); |
| #endif |
| } |
| |
| // The IRBuilder. |
| using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>; |
| // The current function |
| llvm::Function* function_; |
| // Internal builder |
| std::unique_ptr<IRBuilder> builder_; |
| // The module to be returned; |
| std::unique_ptr<llvm::Module> module_; |
| std::unique_ptr<llvm::DataLayout> data_layout_; |
| // Internal metabuilder |
| std::unique_ptr<llvm::MDBuilder> md_builder_; |
| // llvm target info |
| LLVMTarget* llvm_target_{nullptr}; |
| // helpful data types |
| llvm::Type* t_void_{nullptr}; |
| llvm::PointerType* t_void_p_{nullptr}; |
| llvm::Type* t_int_{nullptr}; |
| llvm::Type* t_char_{nullptr}; |
| llvm::Type* t_int8_{nullptr}; |
| llvm::Type* t_int16_{nullptr}; |
| llvm::Type* t_int32_{nullptr}; |
| llvm::Type* t_int64_{nullptr}; |
| llvm::Type* t_float64_{nullptr}; |
| // meta data |
| llvm::MDNode* md_very_likely_branch_{nullptr}; |
| llvm::MDNode* md_tbaa_root_{nullptr}; |
| llvm::MDNode* md_tbaa_alias_set_{nullptr}; |
| // modules to be linked. |
| std::vector<std::unique_ptr<llvm::Module>> link_modules_; |
| /*! \brief native vector bits of current targetx*/ |
| int native_vector_bits_{0}; |
| /*! \brief the storage scope of allocation */ |
| std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_; |
| // The definition of local variable. |
| std::unordered_map<const VarNode*, llvm::Value*> var_map_; |
| // global strings |
| std::unordered_map<std::string, llvm::Constant*> str_map_; |
| // Whether current function is restricted |
| bool is_restricted_{true}; |
| // The analyzer information |
| std::unique_ptr<arith::Analyzer> analyzer_; |
| // set of var that are not restricted(can alias) |
| std::unordered_set<const VarNode*> alias_var_set_; |
| // set of volatile buffer. |
| std::unordered_set<const VarNode*> volatile_buf_; |
| // deep comparison of PrimExpr |
| ExprDeepEqual deep_equal_; |
| // binding of let variables. Enables duplicate var defs that map to same value |
| std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_; |
| // Cache potential common path ops to slightly improve lookup time. |
| // global symbol table. |
| OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol"); |
| const Op& builtin_call_extern_ = builtin::call_extern(); |
| const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); |
| const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); |
| const Op& builtin_call_llvm_pure_intrin_ = builtin::call_llvm_pure_intrin(); |
| const Op& builtin_lookup_param_ = builtin::lookup_param(); |
| const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); |
| |
| /*! \brief Helper struct for debug infos. */ |
| struct DebugInfo { |
| ~DebugInfo(); // Because of the std::unique_ptr. |
| std::unique_ptr<llvm::DIBuilder> di_builder_; |
| llvm::DICompileUnit* compilation_unit_{nullptr}; |
| llvm::DIFile* file_{nullptr}; |
| }; |
| /*! |
| * \brief Create a new DebugInfo struct from the given Module that |
| * initializes file and compilation_unit_ to TVM defaults. |
| */ |
| static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module); |
| }; |
| |
| inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) { |
| #if TVM_LLVM_VERSION >= 120 |
| return llvm::cast<llvm::FixedVectorType>(vec->getType())->getNumElements(); |
| #else |
| return llvm::cast<llvm::VectorType>(vec->getType())->getNumElements(); |
| #endif |
| } |
| |
| template <typename IterType, typename ConvType> |
| void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) { |
| std::vector<PrimFunc> funcs; |
| for (auto it = begin; it != end; ++it) { |
| funcs.push_back(pfunc(*it)); |
| } |
| std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) { |
| std::string name_a = func_a->GetAttr<String>(tvm::attr::kGlobalSymbol).value(); |
| std::string name_b = func_b->GetAttr<String>(tvm::attr::kGlobalSymbol).value(); |
| return name_a < name_b; |
| }); |
| for (auto& f : funcs) { |
| auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
| AddFunction(f); |
| } |
| } |
| |
| } // namespace codegen |
| } // namespace tvm |
| |
| #endif // TVM_LLVM_VERSION |
| #endif // TVM_TARGET_LLVM_CODEGEN_LLVM_H_ |