| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file tvm/ir_pass.h |
| * \brief Collection of IR pass functions |
| * |
| * When a pass function in this file operates on Stmt, |
| * it can be applied to an Expr via PassFunction(Evaluate(expr)). |
| */ |
| #ifndef TVM_IR_PASS_H_ |
| #define TVM_IR_PASS_H_ |
| |
| #include <arithmetic/Simplify.h> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <vector> |
| #include <string> |
| #include "expr.h" |
| #include "buffer.h" |
| #include "schedule.h" |
| #include "lowered_func.h" |
| |
| namespace tvm { |
| namespace ir { |
| |
| /*! |
| * \brief Simplify the expression. |
| * \param expr The expression to be simplified. |
| * \param vrange Range information about the variables in expr; empty by default. |
| * \return The simplified expression. |
| */ |
| TVM_DLL Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>()); |
| |
| /*! |
| * \brief Simplify the statement. |
| * \param stmt The statement to be simplified. |
| * \param vrange Range information about the variables in stmt; empty by default. |
| * \return The simplified statement. |
| */ |
| Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>()); |
| |
| /*! |
| * \brief Simplify the statement by applying canonical form. |
| * \param stmt The statement to be canonically simplified. |
| * \param vrange Range information about the variables in stmt; empty by default. |
| * \return The canonicalized statement. |
| */ |
| Stmt CanonicalSimplify(Stmt stmt, |
| Map<Var, Range> vrange = Map<Var, Range>()); |
| |
| /*! |
| * \brief Simplify the expression by applying canonical form. |
| * \param expr The expression to be canonically simplified. |
| * \param vrange Range information about the variables in expr; empty by default. |
| * \return The canonicalized expression. |
| */ |
| TVM_DLL Expr CanonicalSimplify(Expr expr, |
| Map<Var, Range> vrange = Map<Var, Range>()); |
| |
| /*! |
| * \brief Deep-compare two expressions for equality. |
| * \param lhs The left operand. |
| * \param rhs The right operand. |
| * \return true if lhs and rhs are deeply equal, false otherwise. |
| */ |
| TVM_DLL bool Equal(const Expr& lhs, const Expr& rhs); |
| |
| /*! |
| * \brief Deep-compare two statements for equality. |
| * \param lhs The left operand. |
| * \param rhs The right operand. |
| * \return true if lhs and rhs are deeply equal, false otherwise. |
| */ |
| bool Equal(const Stmt& lhs, const Stmt& rhs); |
| |
| /*! |
| * \brief Deep-compare lhs and rhs, imposing a total order on expressions. |
| * |
| * If only an equality check is needed, use Equal instead, |
| * which also ties definitions. Compare gives expressions |
| * a total order, so it can be used for sorting. |
| * |
| * \param lhs The left operand. |
| * \param rhs The right operand. |
| * \return The comparison result. NOTE(review): presumably negative / zero / |
| * positive for lhs < / == / > rhs; confirm against the implementation. |
| */ |
| int Compare(const Expr& lhs, const Expr& rhs); |
| |
| /*! |
| * \brief Verify whether the IR statement is in SSA form, |
| * i.e. each VarExpr is defined and assigned only once (in Let/For). |
| * |
| * \param ir The root of the IR DAG. |
| * \return Whether the IR is in SSA form. |
| * \note All the passes in this file use SSA form and output SSA form. |
| */ |
| TVM_DLL bool VerifySSA(const Stmt& ir); |
| |
| /*! |
| * \brief Whether the expression has side effects. |
| * \param e The expression to be checked. |
| * \return Whether e has side effects. |
| */ |
| TVM_DLL bool HasSideEffect(const Expr& e); |
| |
| /*! |
| * \brief Whether expression e uses the variable v. |
| * \param e The expression to be checked. |
| * \param v The variable. |
| * \return Whether e uses v. |
| */ |
| bool ExprUseVar(const Expr& e, const Var& v); |
| |
| /*! |
| * \brief Whether expression e uses any variable in the variable set. |
| * \param e The expression to be checked. |
| * \param vset The variable set, given as pointers to Variable nodes. |
| * \return Whether e uses any variable in vset. |
| */ |
| bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset); |
| |
| /*! |
| * \brief Convert an IR node to SSA form. |
| * \param stmt The source statement to be converted. |
| * \return The converted statement. |
| */ |
| TVM_DLL Stmt ConvertSSA(Stmt stmt); |
| |
| /*! |
| * \brief Replace every variable whose node is a key of value_map |
| * with the corresponding mapped value. |
| * \param stmt The source statement to be substituted. |
| * \param value_map The map from variable nodes to their new values. |
| * \return The substituted statement. |
| */ |
| Stmt Substitute(Stmt stmt, |
| const std::unordered_map<const Variable*, Expr>& value_map); |
| |
| /*! |
| * \brief Replace every variable whose node is a key of value_map |
| * with the corresponding mapped value. |
| * \param expr The source expression to be substituted. |
| * \param value_map The map from variable nodes to their new values. |
| * \return The substituted expression. |
| */ |
| Expr Substitute(Expr expr, |
| const std::unordered_map<const Variable*, Expr>& value_map); |
| |
| /*! |
| * \brief Replace every variable that is a key of value_map with its mapped value. |
| * \param stmt The source statement to be substituted. |
| * \param value_map The map from variables to their new values. |
| * \return The substituted statement. |
| */ |
| Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map); |
| |
| /*! |
| * \brief Replace every variable that is a key of value_map with its mapped value. |
| * \param expr The source expression to be substituted. |
| * \param value_map The map from variables to their new values. |
| * \return The substituted expression. |
| */ |
| Expr Substitute(Expr expr, const Map<Var, Expr>& value_map); |
| |
| /*! |
| * \brief Inline all calls of f in stmt. |
| * |
| * \param stmt The statement to apply the inline optimization to. |
| * \param f The function reference to be inlined. |
| * \param args The argument variables of the function. |
| * \param body The definition body of the function. |
| * \return The resulting statement. |
| * |
| * \note All the passes in this file use SSA form and output SSA form. |
| */ |
| Stmt Inline(Stmt stmt, |
| FunctionRef f, |
| Array<Var> args, |
| Expr body); |
| |
| /*! |
| * \brief Flatten multi-dimensional reads/writes |
| * into single-dimensional Load/Store. |
| * |
| * \param stmt The stmt to be transformed. |
| * \param extern_buffer Map specifying the external buffer |
| * assigned to each input and output tensor. |
| * \param cache_line_size The size of the CPU cache line. |
| * \param create_bound_attribute Whether to create bound attributes (false by default). |
| * \return Transformed stmt. |
| */ |
| Stmt StorageFlatten(Stmt stmt, |
| Map<Tensor, Buffer> extern_buffer, |
| int cache_line_size, |
| bool create_bound_attribute = false); |
| |
| /*! |
| * \brief Remove no-ops from the Stmt. |
| * \param stmt The stmt to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt RemoveNoOp(Stmt stmt); |
| |
| /*! |
| * \brief Split the statement into pipeline stages. |
| * \param stmt The stmt to be split. |
| * \param split_load Whether to split loads into their own stage. |
| * \return Transformed stmt. |
| */ |
| Stmt SplitPipeline(Stmt stmt, bool split_load); |
| |
| /*! |
| * \brief Narrow channel accesses to a smaller range. |
| * \param stmt The stmt in which to rewrite accesses. |
| * \return Transformed stmt. |
| */ |
| Stmt NarrowChannelAccess(Stmt stmt); |
| |
| /*! |
| * \brief Unroll the constant loops marked by unroll. |
| * This pass also automatically attaches the pragma unroll tag to loops that meet the criteria. |
| * |
| * \param stmt The statement to be unrolled. |
| * \param auto_max_step The maximum step beyond which automatic unroll is no longer attached. |
| * \param auto_max_depth The maximum depth beyond which automatic unroll is no longer attached. |
| * \param auto_max_extent The maximum extent of a loop that can be unrolled; |
| * this is a legacy option that does not take the loop's total steps into account. |
| * \param explicit_unroll Whether to explicitly unroll the loop, or leave the unroll annotation to codegen. |
| * \return Transformed stmt. |
| */ |
| Stmt UnrollLoop(Stmt stmt, |
| int auto_max_step, |
| int auto_max_depth, |
| int auto_max_extent, |
| bool explicit_unroll); |
| |
| /*! |
| * \brief Vectorize the constant loops. |
| * \param stmt The statement to be vectorized. |
| * \return Transformed stmt. |
| */ |
| Stmt VectorizeLoop(Stmt stmt); |
| |
| /*! |
| * \brief Instrument bound checkers. |
| * \param stmt The statement to be instrumented. |
| * \return The instrumented statement. |
| */ |
| Stmt InstrumentBoundCheckers(Stmt stmt); |
| |
| /*! |
| * \brief Inject virtual thread loops into stmt. |
| * \param stmt The statement to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt InjectVirtualThread(Stmt stmt); |
| |
| /*! |
| * \brief Inject prefetch instructions into stmt. |
| * \param stmt The statement to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt InjectPrefetch(Stmt stmt); |
| |
| /*! |
| * \brief Inject double buffering into stmt. |
| * \param stmt The statement to be transformed. |
| * \param split_loop Loop splitting factor. |
| * \return Transformed stmt. |
| */ |
| Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); |
| |
| /*! |
| * \brief Inject copy intrinsics with optional padding. |
| * |
| * \param stmt The statement to be transformed. |
| * \param pragma_key The pragma key used as the copy hint. |
| * \param fintrin The packed function with signature |
| * |
| * Stmt fintrin(Buffer src, |
| * Buffer dst, |
| * Array<Expr> pad_before, |
| * Array<Expr> pad_after, |
| * Expr pad_value) |
| * |
| * NOTE(review): presumably called to build the replacement statement |
| * for each hinted copy; confirm against the implementation. |
| * \return Transformed stmt. |
| */ |
| Stmt InjectCopyIntrin(Stmt stmt, |
| const std::string& pragma_key, |
| const runtime::PackedFunc& fintrin); |
| |
| /*! |
| * \brief Rewrite the storage allocation pattern. |
| * Moves each allocation to the outermost possible scope and |
| * tries to share space between allocations, producing |
| * a static allocation plan when possible. |
| * |
| * \param stmt The stmt to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt StorageRewrite(Stmt stmt); |
| |
| /*! |
| * \brief Partition loops in the stmt. |
| * \param stmt The stmt in which to partition loops. |
| * \param split_const_loop Flag to enable partitioning of constant loops. |
| * \return Transformed stmt. |
| */ |
| Stmt LoopPartition(Stmt stmt, bool split_const_loop); |
| |
| /*! |
| * \brief Detect and insert sync points for a co-processor. |
| * |
| * \param stmt The stmt to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt CoProcSync(Stmt stmt); |
| |
| /*! |
| * \brief Lift common attrs with attr_key to the outer scope. |
| * |
| * \param stmt The stmt to be transformed. |
| * \param attr_key The attribute key to be checked. |
| * \return Transformed stmt. |
| */ |
| Stmt LiftAttrScope(Stmt stmt, std::string attr_key); |
| |
| /*! |
| * \brief Detect and rewrite unsafe Select nodes that contain memory accesses. |
| * \param stmt The statement to be rewritten. |
| * \return Transformed stmt. |
| */ |
| Stmt RewriteUnsafeSelect(Stmt stmt); |
| |
| /*! |
| * \brief Lower attached storage access information. |
| * Run this pass after all storage access analysis has finished. |
| * |
| * \param stmt The stmt to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt LowerStorageAccessInfo(Stmt stmt); |
| |
| /*! |
| * \brief Decorate the stmt with a device scope. This is helpful for |
| * hardware accelerators without thread blocks. |
| * |
| * \param stmt The stmt to be transformed. |
| * \return Transformed stmt. |
| */ |
| Stmt DecorateDeviceScope(Stmt stmt); |
| |
| /*! |
| * \brief Make a user-callable API LoweredFunc. |
| * |
| * The main task of this function is to create code to: |
| * - Map the values in api_args to the Vars required by body. |
| * - Insert assertions to check the type/value of the passed arguments. |
| * |
| * \param body The body of the function. |
| * \param name The name of the function. |
| * \param api_args Arguments to the function; each can be either a Var or a Buffer. |
| * \param num_unpacked_args Number of arguments that |
| * are processed in plain form instead of packed form. |
| * \param is_restricted Whether the caller can guarantee that buffer arguments do not overlap. |
| * It is recommended to set this to true for optimized code if that invariant holds. |
| * |
| * \return A LoweredFunc with the specified signature. |
| * |
| * \note |
| * The function signature has two cases |
| * |
| * let num_packed_args = len(api_args) - num_unpacked_args; |
| * |
| * if num_packed_args is zero: |
| * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) - 1 |
| * |
| * if num_packed_args is not zero: |
| * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, |
| * api_arg_k, api_arg_k+1, ... api_arg_n) |
| * |
| * where n == len(api_args) - 1, k == num_packed_args |
| * |
| * There is no thread_axis in the generated function. |
| */ |
| LoweredFunc MakeAPI(Stmt body, |
| std::string name, |
| Array<NodeRef> api_args, |
| int num_unpacked_args, |
| bool is_restricted); |
| |
| /*! |
| * \brief Bind the device type of the host function to device_type. |
| * \param func The function to be bound. |
| * \param device_type The device type to bind. |
| * \return The bound function. |
| */ |
| LoweredFunc BindDeviceType(LoweredFunc func, |
| int device_type); |
| /*! |
| * \brief Find undefined vars in the statement. |
| * \param stmt The statement to be checked. |
| * \param defs The vars that are already defined. |
| * \return Array of undefined vars. |
| */ |
| Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs); |
| |
| /*! |
| * \brief Split the function into a host function and device functions. |
| * \param func The function to be split. |
| * |
| * \return Array of functions; the first one is the host function, |
| * the others are device functions. |
| */ |
| Array<LoweredFunc> SplitHostDevice(LoweredFunc func); |
| |
| /*! |
| * \brief Insert sync between parallel read/write of shared buffers. |
| * |
| * \param stmt The function to be transformed (a LoweredFunc, despite the parameter name). |
| * \param storage_scope The storage scope considered. |
| * \return The transformed function. |
| */ |
| LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope); |
| |
| /*! |
| * \brief Lower cross-thread allreduce in the function. |
| * \param f The device function to be lowered. |
| * \param warp_size The size of a warp, within which no sync is needed. |
| * \return Transformed function. |
| */ |
| LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size); |
| |
| /*! |
| * \brief Lower warp memory in the function. |
| * \param f The device function to be lowered. |
| * \param warp_size The size of a warp, within which no sync is needed. |
| * This function only takes effect if warp_size is bigger than one. |
| * \return Transformed function. |
| */ |
| LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); |
| |
| /*! |
| * \brief Remap the thread axis. |
| * |
| * This can be used to get an equivalent program that uses |
| * threadIdx.y in place of threadIdx.x by passing |
| * {"threadIdx.x": thread_axis("threadIdx.y")} |
| * |
| * |
| * \param f The device function to be lowered. |
| * \param axis_map The map from StringImm -> IterVar. |
| * \return Transformed function. |
| */ |
| LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map); |
| |
| /*! |
| * \brief Lower packed function calls. |
| * \param f The function to be lowered. |
| * \return Transformed function. |
| */ |
| LoweredFunc LowerTVMBuiltin(LoweredFunc f); |
| |
| /*! |
| * \brief Combine context function calls. |
| * \param f The host function to be lowered. |
| * \return Transformed function. |
| */ |
| LoweredFunc CombineContextCall(LoweredFunc f); |
| |
| /*! |
| * \brief Rewrite the pointer content type of arguments, |
| * as well as Allocs internal to the function, to use |
| * the most frequently accessed type for load/store, |
| * avoiding pointer casts in the backend when possible. |
| * |
| * \note Implemented in storage_rewrite.cc. |
| * \param f The function to be transformed. |
| * \return Transformed function. |
| */ |
| LoweredFunc PointerValueTypeRewrite(LoweredFunc f); |
| |
| /*! |
| * \brief Lower intrinsic function calls. |
| * \param f The device function to be lowered. |
| * \param target The target device, given as a string. |
| * \return Transformed function. |
| */ |
| LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); |
| |
| /*! |
| * \brief Verify whether memory accesses are legal for a specific target device type. |
| * |
| * When the target is CUDA and not all of the workload is bound to |
| * threads, CPU code is generated that tries to access GPU memory, |
| * which is illegal. This pass verifies that case. |
| * |
| * \param func The function to be verified. |
| * \param device_type The target device type. |
| * \return Whether memory verification succeeded. |
| */ |
| bool VerifyMemory(LoweredFunc func, int device_type); |
| |
| |
| /*! |
| * \brief Verify the correctness of GPU code. |
| * It checks whether the amount of memory used or the number of threads |
| * in a block exceeds the limit. |
| * \param stmt The statement to be checked. |
| * \param constraints The dict specifying the constraints to check. |
| * Possible keys are |
| * |
| * "max_local_memory_per_block": Total amount of local memory per block (in bytes). |
| * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). |
| * "max_threads_per_block": Maximum number of threads per block. |
| * "max_thread_x": Maximum length of threadIdx.x. |
| * "max_thread_y": Maximum length of threadIdx.y. |
| * "max_thread_z": Maximum length of threadIdx.z. |
| * |
| * If a key is missing from this argument, the pass does not check that item. |
| * \return Whether the code is valid GPU code. |
| * |
| */ |
| bool VerifyGPUCode(Stmt stmt, |
| Map<std::string, Expr> constraints); |
| |
| |
| } // namespace ir |
| } // namespace tvm |
| |
| #endif // TVM_IR_PASS_H_ |