| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/tir/transform.h |
| * \brief TIR specific transformation passes. |
| */ |
| #ifndef TVM_TIR_TRANSFORM_H_ |
| #define TVM_TIR_TRANSFORM_H_ |
| |
| #include <tvm/ir/transform.h> |
| #include <tvm/target/target.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/function.h> |
| |
| #include <string> |
| #include <vector> |
| |
| namespace tvm { |
| namespace tir { |
| namespace transform { |
| |
| using tvm::transform::Pass; |
| using tvm::transform::PassContext; |
| using tvm::transform::PassContextNode; |
| using tvm::transform::PassInfo; |
| using tvm::transform::PassInfoNode; |
| using tvm::transform::PassNode; |
| using tvm::transform::Sequential; |
| |
| /*! |
| * \brief Create a function pass that optimizes PrimFuncs. |
| * |
| * \param pass_func The packed function that contains the optimization. |
| * \param opt_level The optimization level of the function pass. |
| * \param name The name of the function pass. |
| * \param required The list of the passes that the function pass is dependent on. |
| * |
| * \return The created function pass. |
| */ |
| TVM_DLL Pass CreatePrimFuncPass( |
| const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, |
| int opt_level, String name, tvm::Array<String> required); |
| |
| /*! |
| * \brief Inject prefetch instructions into stmt. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass InjectPrefetch(); |
| |
| // TODO(tvm-team): consolidate configs to the PassContext |
| /*! |
| * \brief Flatten the multi-dimensional read/write |
| * to single dimensional Load/Store |
| * |
| * \param cache_line_size The size of CPU cache line. |
| * \param create_bound_attribute Whether to create bound attributes. |
| * |
| * \return The Pass |
| */ |
| TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); |
| |
| /*! |
| * \brief Inject copy intrinsics with optional pad. |
| * |
| * \param pragma_key The pragma key for hint of copy. |
| * \param fintrin The function with signature |
| * |
| * Stmt fintrin(Buffer src, |
| * Buffer dst, |
| * Array<Expr> pad_before, |
| * Array<Expr> pad_after, |
| * Expr pad_value) |
| * \return The pass. |
| */ |
| TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin); |
| |
| /*! |
| * \brief Detect and insert sync points to co-processor. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass CoProcSync(); |
| |
| /*! |
| * \brief Lift common attrs with attr_key to outer scope. |
| * |
| * \param attr_key The attribute key to be checked. |
| * \return The pass. |
| */ |
| TVM_DLL Pass LiftAttrScope(String attr_key); |
| |
| /*! |
| * \brief Partition loops in the stmt. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass LoopPartition(); |
| |
| /*! |
| * \brief Lower vectorization loops. |
| * |
| * \param enable_vectorize Whether vectorization is enabled. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); |
| |
| /*! |
| * \brief Inject virtual thread loops. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass InjectVirtualThread(); |
| |
| /*! |
| * \brief Inject double buffer statements. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass InjectDoubleBuffer(); |
| |
| /*! |
| * \brief Rewrite storage allocation pattern. |
| * Moves the allocation to the outermost possible scope. |
| * Tries to share space between allocations to make |
| * a static allocation plan when possible. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass StorageRewrite(); |
| |
| /*! |
| * \brief Unroll the constant loop marked by unroll. |
| * This pass also automatically attaches the pragma unroll tag to loops that meet the standard. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass UnrollLoop(); |
| |
| /*! |
| * \brief Remove No Op from the Stmt. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass RemoveNoOp(); |
| |
| /*! |
| * \brief Detect and rewrite unsafe select that contains memory access. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass RewriteUnsafeSelect(); |
| |
| /*! |
| * \brief Run arithmetic simplifications on the statements and expressions. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass Simplify(); |
| |
| /*! |
| * \brief Instruments bound checkers. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass InstrumentBoundCheckers(); |
| |
| /*! |
| * \brief Transform the high-level PrimFunc to a low-level version |
| * that can be used as an API function. |
| * |
| * The main task of this function is to create code to: |
| * - Map the values in the api_args to Var that is required by body. |
| * - Insert assertions to check type/value of the passed arguments. |
| * |
| * \param num_unpacked_args Number of arguments that |
| * are processed in plain form instead of packed form. |
| * |
| * \note |
| * The function signature has two cases |
| * |
| * let num_packed_args = len(api_args) - num_unpacked_args; |
| * |
| * if num_packed_args is zero: |
| * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) |
| * |
| * if num_packed_args is not zero: |
| * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, |
| * api_arg_k, api_arg_k+1, ... api_arg_n, |
| * TVMValue* out_ret_val, int* out_ret_tcode) |
| * |
| * where n == len(api_args), k == num_packed_args |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass MakePackedAPI(int num_unpacked_args); |
| |
| /*! |
| * \brief Transform the high-level PrimFunc to a C signature that can be used |
| * to call the operator directly. |
| * |
| * The main task of this function is to create code that maps the values in the |
| * api_args to Var that is required by body |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass MakeUnpackedAPI(); |
| |
| /*! |
| * \brief Remap the thread axis |
| * |
| * This can be used to get equivalent program which uses |
| * threadIdx.y in place of threadIdx.x by passing |
| * {"threadIdx.x": thread_axis("threadIdx.y")} |
| * |
| * \param axis_map The mapping from the name of a thread axis to the IterVar |
| * that should replace it (see the example above). |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map); |
| |
| /*! |
| * \brief Lower custom datatypes. |
| * |
| * See tvm::datatypes::Registry for more information on adding custom datatypes. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerCustomDatatypes(); |
| |
| /*! |
| * \brief Decorate all the function's body as device function. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass DecorateDeviceScope(); |
| |
| /*! |
| * \brief Split the function into a host function and device functions. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass SplitHostDevice(); |
| |
| /*! |
| * \brief Skip assert stmt. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass SkipAssert(); |
| |
| /*! |
| * \brief Insert sync between parallel read/write of shared buffers. |
| * |
| * \param storage_scope The storage scope considered. |
| * \return The pass. |
| */ |
| TVM_DLL Pass ThreadSync(String storage_scope); |
| |
| /*! |
| * \brief Lower cross thread allreduce. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerThreadAllreduce(); |
| |
| /*! |
| * \brief Infer the TensorCore fragment information using tensor intrinsics. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass InferFragment(); |
| |
| /*! |
| * \brief This annotation is for nodes to be disabled for builtin lowering |
| */ |
| static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin"; |
| |
| /*! |
| * \brief Lower builtin intrinsics. |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerTVMBuiltin(); |
| |
| /*! |
| * \brief Lower the target specific function intrinsics in each of the function. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerIntrin(); |
| |
| /*! |
| * \brief Lower warp memory access to low-level device related function calls. |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerWarpMemory(); |
| |
| /*! |
| * \brief Lower attached storage access information on device. |
| * |
| * \note Run this pass after all storage access analysis finish. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerDeviceStorageAccessInfo(); |
| |
| /*! |
| * \brief Combine context calls in the host function. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass CombineContextCall(); |
| |
| /*! |
| * \brief Narrow down PrimExpr datatype in stmt to target_bits. |
| * |
| * \param target_bits The target bits |
| * |
| * \note Run this pass after storage flatten. |
| * \return The pass. |
| */ |
| TVM_DLL Pass NarrowDataType(int target_bits); |
| |
| /*! |
| * \brief Legalize bf16 typed Ops. Add a cast to fp32 |
| * before Ops, then add a cast back to bf16. |
| * \return The pass. |
| */ |
| TVM_DLL Pass BF16Legalize(); |
| |
| /*! |
| * \brief Rewrite the pointer content type of arguments, |
| * as well as Alloc internal to the function to use |
| * the most frequently accessed type for load/store |
| * to avoid pointer casting in backend when possible. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass PointerValueTypeRewrite(); |
| |
| /*! |
| * \brief Hoist loop-invariant IfThenElse nodes to |
| * outside the eligible loops. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass HoistIfThenElse(); |
| |
| /*! |
| * \brief Hoist loop-invariant expression nodes to |
| * outside the eligible loops. |
| * |
| * Can hoist conditionals used in IfThenElse statements and |
| * expressions, bindings of variables in Let statements and |
| * expressions, or boolean expressions, configurable to enable/disable |
| * each hoistable type. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass HoistExpression(); |
| |
| /*! |
| * \brief Lower cross-thread reduction from thread |
| * bindings to intrinsic function calls. |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerCrossThreadReduction(); |
| |
| /*! |
| * \brief Lower block init stmt into IfThenElse stmts |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerInitBlock(); |
| |
| /*! |
| * \brief Locate the buffer allocation to the exact position (usually |
| * the LCA, i.e. lowest common ancestor, of the buffer accesses). |
| * This pass will inject opaque block with alloc_buffers at the |
| * allocation site. |
| * \return The pass. |
| */ |
| TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); |
| |
| /*! |
| * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the |
| * corresponding iter_values in BlockRealize, and turn the blocks into opaque blocks by |
| * removing all the iter_values in BlockRealize and iter_vars in Block. |
| * \return The pass. |
| */ |
| TVM_DLL Pass ConvertBlocksToOpaque(); |
| |
| /*! |
| * \brief Compact the buffer access region by removing the buffer regions that are not accessed, |
| * i.e. narrowing the buffer shape and adjusting the access region if necessary. |
| * |
| * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. |
| * |
| * \code |
| * |
| * for i in range(0, 16): |
| *     with T.block(): |
| *         B = T.alloc_buffer(16, 16) |
| *         for j in range(0, 16): |
| *             B[i, j] = A[i, j] + 1 |
| *         for j in range(0, 16): |
| *             C[i, j] = B[i, j] + 1 |
| * |
| * \endcode |
| * |
| * This pass narrows the buffer shape and adjusts its accessed region accordingly. |
| * In this particular case, because only a `1 * 16` vector of `B` is accessed, |
| * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`. |
| * |
| * \code |
| * |
| * for i in range(0, 16): |
| *     with T.block(): |
| *         B = T.alloc_buffer(1, 16) |
| *         for j in range(0, 16): |
| *             B[0, j] = A[i, j] + 1 |
| *         for j in range(0, 16): |
| *             C[i, j] = B[0, j] + 1 |
| * |
| * \endcode |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass CompactBufferAllocation(); |
| |
| /*! |
| * \brief This pass legalizes packed calls by wrapping their arguments into TVMValues |
| * \return The pass. |
| */ |
| TVM_DLL Pass LegalizePackedCalls(); |
| |
| /*! |
| * \brief Remove match buffers inside the block. Also, it will validate the binding. |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerMatchBuffer(); |
| |
| /*! |
| * \brief Remove the block to ensure that the TIR can not be scheduled again. |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerOpaqueBlock(); |
| |
| /*! |
| * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional |
| * BufferLoad/BufferStore for TIR that does not contain opaque blocks. |
| * \return The pass. |
| */ |
| TVM_DLL Pass FlattenBuffer(); |
| |
| /*! |
| * \brief Flatten the multi-dimensional read/write |
| * to two dimensional texture Load/Store and realize |
| * texture buffer allocations. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass TextureFlatten(); |
| |
| /*! |
| * \brief Lower VTCM allocations |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerVtcmAlloc(); |
| |
| /*! |
| * \brief Lower Async TIR primitives to DMA copy and wait builtins |
| * \return The pass. |
| */ |
| TVM_DLL Pass LowerAsyncDMA(); |
| |
| /*! |
| * \brief Implements a Common Subexpression Elimination (CSE) for TIR |
| * which introduces let-in bindings for duplicated sub-expressions. |
| * \param enable_cse_tir Whether common subexpression elimination is enabled. |
| * \param identify_equiv_terms Whether equivalent terms should be identified. |
| * \return The pass. |
| */ |
| TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false); |
| |
| /*! |
| * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and |
| * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., |
| * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the |
| * unification, we use a consolidated IterVar and a variable for them. |
| * \return The pass. |
| * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread` |
| * are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z` |
| * instead. |
| */ |
| TVM_DLL Pass UnifyThreadBinding(); |
| |
| /*! |
| * \brief A pass to merge multiple TIR-level dynamic shared memory allocations into one |
| * \return The pass. |
| */ |
| TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); |
| |
| /*! |
| * \brief This pass is post-scheduling pass to convert all |
| * Parallel For loops to Serial ones. This is run |
| * to attain lesser memory and/or executor/backend |
| * does not support parallel launch of For loops. |
| * \return The pass. |
| */ |
| TVM_DLL Pass ConvertForLoopsToSerial(); |
| |
| /*! |
| * \brief This is the unified static memory planner pass that will |
| * plan for memory intra- and inter- PrimFuncs together. The pass |
| * requires all the function to be PrimFuncs including the main. |
| * \return The pass. |
| */ |
| TVM_DLL Pass UnifiedStaticMemoryPlanner(); |
| |
| /*! |
| * \brief This pass transforms annotated loops into pipelined ones where producers and consumers |
| * are overlapped with the information provided in loop annotations, which enables optimization |
| * techniques like prefetching and pipeline parallelism. |
| * |
| * The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize, |
| * Block, SeqStmt), and the number of children is denoted by `n` in the documentation. |
| * |
| * The following annotations are used to guide the loop transformation: |
| * |
| * 1) Loop annotation `software_pipeline_stage` defines the pipeline stage. |
| * An array of `n` integers, and each element should be in range [0, max_stage], |
| * where max_stage is the maximum (inclusive) stage. |
| * 2) Loop annotation `software_pipeline_order` defines the pipeline order. |
| * An array of `n` integers, a permutation of [0, 1, ..., n - 1]; |
| * 3) Block annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of |
| * read/write dependency. It's an integer index of the write regions of the block. |
| * |
| * Every annotated loop is transformed into a loop with three blocks as its direct children: |
| * |
| * 1) Prologue block, where components whose stage is less than `max_stage` are executed; |
| * |
| * 2) Body block, where all the components are executed; |
| * |
| * 3) Epilogue block, where only components whose stage is greater than 0 will be executed. |
| * The execution order is controlled by the annotation `software_pipeline_order`, |
| * and thus could be different than the original order. |
| * |
| * Note: For nested software pipelines, the inner software pipeline will be generated first, |
| * which may affect the number of the direct children of the outer loop. |
| * In this case, the annotations for the outer software |
| * pipeline should include the result of the inner software pipeline, |
| * which is the three blocks as discussed above. |
| * |
| * Example: |
| * |
| * Before this pass, the TIR is: |
| * |
| * \code{.py} |
| * @T.prim_func |
| * def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: |
| *     for tx in T.thread_binding(0, 16, thread="threadIdx.x"): |
| *         for i in T.serial(0, 16, |
| *                           annotations={"software_pipeline_stage": [0, 1], |
| *                                        "software_pipeline_order": [0, 1]} |
| *                           ): |
| *             with T.block(): |
| *                 T.reads(A[tx, i]) |
| *                 T.writes(C[tx, i]) |
| *                 B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") |
| *                 with T.block("B"): |
| *                     T.reads(A[tx, i]) |
| *                     T.writes(B[tx, 0]) |
| *                     B[tx, 0] = A[tx, i] * T.float32(2) |
| *                 with T.block("C"): |
| *                     T.reads(B[tx, 0]) |
| *                     T.writes(C[tx, i]) |
| *                     C[tx, i] = B[tx, 0] + T.float32(1) |
| * \endcode |
| * |
| * The TIR above annotates the loop as a two-stage pipeline with no reordering. |
| * After applying this pass, the TIR is transformed into: |
| * |
| * \code{.py} |
| * @T.prim_func |
| * def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None: |
| *     for tx in T.thread_binding(0, 16, thread="threadIdx.x"): |
| *         with T.block(): |
| *             T.reads([A[tx, 0:16]]) |
| *             T.writes([C[tx, 0:16]]) |
| *             B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") |
| *             with T.block("prologue"): |
| *                 T.reads([A[tx, 0]]) |
| *                 T.writes([B[0, tx, 0]]) |
| *                 B[0, tx, 0] = A[tx, 0] * T.float32(2) |
| *             with T.block("body"): |
| *                 T.reads([A[tx, 1:16], B[0:2, tx, 0]]) |
| *                 T.writes([B[0:2, tx, 0], C[tx, 0:15]]) |
| *                 for i in T.serial(0, 15): |
| *                     with T.block("B"): |
| *                         T.reads([A[tx, i + 1]]) |
| *                         T.writes([B[(i + 1) % 2, tx, 0]]) |
| *                         B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) |
| *                     with T.block("C"): |
| *                         T.reads([B[i % 2, tx, 0]]) |
| *                         T.writes([C[tx, i]]) |
| *                         C[tx, i] = B[i % 2, tx, 0] + T.float32(1) |
| *             with T.block("epilogue"): |
| *                 T.reads([B[1, tx, 0]]) |
| *                 T.writes([C[tx, 15]]) |
| *                 C[tx, 15] = B[1, tx, 0] + T.float32(1) |
| * \endcode |
| * |
| * The original loop has two blocks, B and C, as its direct children. The loop annotations indicate |
| * that block B has stage == 0, order == 0, block C has stage == 1, order == 1. Therefore, block B |
| * should be executed in advance of block C by one iteration. The order 0 and 1 specifies the order |
| * of block B and C inside the body block inside the result TIR. |
| * |
| * \return The IR transform pass. |
| */ |
| TVM_DLL Pass InjectSoftwarePipeline(); |
| |
| /*! |
| * \brief Create a pass that binds the given constants. |
| * |
| * \param constants The constant NDArrays to bind. |
| * |
| * \note NOTE(review): this declaration carried no documentation; presumably the |
| * constants are bound to the parameters of a PrimFunc -- confirm against |
| * the implementation before relying on this description. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants); |
| |
| /*! |
| * \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute. |
| * |
| * \return The pass. |
| */ |
| TVM_DLL Pass ExtractPrimFuncConstants(); |
| |
| /*! |
| * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) |
| * \return The pass. |
| */ |
| TVM_DLL Pass RenormalizeSplitPattern(); |
| |
| /*! |
| * \brief Annotate a PrimFunc with a given target. |
| * \param target The target to annotate the PrimFunc with. |
| * \return The pass. |
| */ |
| TVM_DLL Pass BindTarget(Target target); |
| |
| /*! |
| * \brief Set a PrimFunc as the entry point if it is the only function in the IRModule. |
| * \return The pass. |
| */ |
| TVM_DLL Pass AnnotateEntryFunc(); |
| |
| /*! |
| * \brief Filter PrimFuncs with a given condition. |
| * \param fcond The condition, a predicate taking a PrimFunc and returning a bool. |
| * \return The pass. |
| */ |
| TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond); |
| |
| /*! |
| * \brief Pass to rewrite global to shared memory copy on CUDA with asynchronous copy. |
| * \return The pass. |
| */ |
| TVM_DLL Pass InjectPTXAsyncCopy(); |
| |
| /*! |
| * \brief Remove the weight layout rewrite block |
| * \return The pass. |
| */ |
| TVM_DLL Pass RemoveWeightLayoutRewriteBlock(); |
| |
| /*! |
| * \brief Add the explicit local stage for the shared memory access on GPU. |
| * \return The pass. |
| */ |
| TVM_DLL Pass ManifestSharedMemoryLocalStage(); |
| |
| } // namespace transform |
| } // namespace tir |
| } // namespace tvm |
| |
| #endif // TVM_TIR_TRANSFORM_H_ |