blob: 402fa551543138adb36c8b7f12f722a2140d1493 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/analysis.h
* \brief Analysis utilities and passes for TIR.
*/
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <functional>
#include <string>
namespace tvm {
namespace tir {
/*!
 * \brief Functor that checks whether two PrimExprs are deeply equal,
 *  comparing variables by identity instead of remapping them.
 *
 * Because variable bindings are never remapped, expressions that differ only in
 * the identity of their bound variables are NOT considered equal: for example
 * (let x = 1 in x + 1) and (let y = 1 in y + 1) compare unequal unless x.same_as(y).
 * Use StructuralEqual when such cases must compare equal.
 *
 * Skipping the remapping lets this comparison run faster than StructuralEqual,
 * so it can be used as a utility function during arithmetic simplifications.
 *
 * \sa StructuralEqual
 */
struct ExprDeepEqual {
  /*!
   * \brief Compare two expressions without var remapping.
   * \param lhs The left-hand expression.
   * \param rhs The right-hand expression.
   * \return Whether lhs and rhs are deeply equal.
   */
  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
/*!
* \brief Visit the PrimFuncs in the IRModule
* \tparam FLambda The type of the PrimFunc visitor
* \param mod The IRModule to be visited
* \param fvisit The visitor to the PrimFuncs in the IRModule
*/
template <class FLambda>
inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
for (const auto& kv : mod->functions) {
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
fvisit(prim_func);
}
}
}
/*!
 * \brief Estimate the FLOPs of a TIR fragment.
 * \param stmt The TIR fragment to be estimated.
 * \return The estimated FLOPs of `stmt`.
 */
TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
/*!
 * \brief Estimate the FLOPs of the TIR functions in an IRModule.
 * \param mod The IRModule to be estimated.
 * \return The estimated FLOPs of `mod`.
 */
TVM_DLL double EstimateTIRFlops(const IRModule& mod);
/*!
 * \brief Find undefined vars in the statement.
 * \param stmt The statement to be checked.
 * \param defs The vars that are already defined.
 * \return Array of undefined vars.
 */
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
/*!
 * \brief Find undefined vars in the expression.
 * \param expr The expression to be checked.
 * \return Array of undefined vars.
 */
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
/*!
 * \brief Analyze the side effect of an expression.
 * \param expr The expression to be checked.
 *
 * \return The CallEffectKind of `expr`, which can be kPure, kReadState or kUpdateState.
 */
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
/*!
 * \brief Whether the given Stmt uses any var in the given variable set.
 * \param stmt The Stmt to be checked.
 * \param vset_contains The predicate that tells whether a var is in the variable set.
 * \return Whether `stmt` uses any var in the given variable set.
 */
TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
/*!
 * \brief Whether the given PrimExpr uses any var in the given variable set.
 * \param expr The PrimExpr to be checked.
 * \param vset_contains The predicate that tells whether a var is in the variable set.
 * \return Whether `expr` uses any var in the given variable set.
 */
TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
/*!
 * \brief Verify whether the given PrimFunc is in SSA form.
 * That is: each Var is defined and assigned only once (in Let/For).
 *
 * \param func The function to be verified.
 * \return Whether the function is in SSA form.
 *
 * \note All passes in TIR consume and produce SSA form.
 */
TVM_DLL bool VerifySSA(const PrimFunc& func);
/*!
 * \brief Verify if memory accesses are legal for a specific target device type.
 *
 * For example, when the target is CUDA and not all of the workload is bound to
 * threads, CPU code is generated that tries to access GPU memory, which is
 * illegal. This analysis checks for that case.
 *
 * \param func The function to be verified.
 * \return Whether the memory verification succeeded.
 */
TVM_DLL bool VerifyMemory(const PrimFunc& func);
/*!
 * \brief Verify the correctness of GPU code.
 * It checks whether the amount of memory usage or the number of threads
 * in a block exceeds the limit.
 * \param func The function to be checked.
 * \param constraints The dict specifying the constraints to check.
 *  Possible keys are:
 *
 *  "max_local_memory_per_block": Total amount of local memory per block (in bytes).
 *  "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
 *  "max_threads_per_block": Maximum number of threads per block.
 *  "max_thread_x": Maximum length of threadIdx.x.
 *  "max_thread_y": Maximum length of threadIdx.y.
 *  "max_thread_z": Maximum length of threadIdx.z.
 *
 *  If a key is missing from this argument, that item is not checked.
 * \return Whether it is valid GPU code.
 */
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
/*!
 * \brief Auto detect the block access region according to its body stmt.
 * It detects the access regions as arrays in order of appearance in the AST.
 * \param block The block to be detected.
 * \param buffer_var_map The outside buffers which may be accessed by the block.
 *  It is a map from buffer var to the buffer.
 * \return Array of access regions.
 *  There are three arrays of BufferRegion:
 *  - first: read regions
 *  - second: write regions
 *  - third: opaque regions
 */
TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);
/*!
 * \brief Auto detect the block read/write region according to its body stmt. An opaque access
 *  is counted as both a read and a write access.
 * \param block The block to be detected.
 * \param buffer_var_map The outside buffers which may be accessed by the block.
 *  It is a map from buffer var to the buffer.
 * \return An array consisting only of the read regions and the write regions of the input block.
 */
TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);
/*!
 * \brief Calculate the expression complexity based on the number of symbols it contains.
 * \param expr The expr to be calculated.
 * \return The complexity of `expr`.
 */
TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
/*!
 * \brief Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.
 * \param func The TIR PrimFunc for which the workspace size is to be calculated.
 * \param workspace_byte_alignment The byte alignment required for each tensor allocated in this
 *  workspace.
 * \return The workspace size in bytes.
 */
TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
const Integer& workspace_byte_alignment);
/*!
 * \brief Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level
 *  accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store and opaque access).
 *  The LCA may be a For loop or a Block.
 * \param func The PrimFunc to be detected.
 * \return The map from each buffer to the LCA of all accesses to it. The LCA is the function
 *  root if the returned stmt is NullOpt.
 */
TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);
/*!
 * \brief Verify if the given TIR is well-formed. The verification includes:
 *  - Check that expressions do not contain vars that are defined outside the block.
 * \param func The PrimFunc to be verified.
 * \param assert_mode Whether to raise an error when the function is not well-formed.
 *  Defaults to true.
 * \return Whether it is a well-formed TIR function.
 */
TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
// Pass variants of the verification analyses declared in this header.
// These passes throw a RuntimeError directly when verification fails.
namespace transform {
using tvm::transform::Pass;
using tvm::transform::PassContext;
/*!
 * \brief Pass variant of VerifySSA.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifySSA
 */
TVM_DLL Pass VerifySSA();
/*!
 * \brief Pass variant of VerifyMemory.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifyMemory
 */
TVM_DLL Pass VerifyMemory();
/*!
 * \brief Pass variant of VerifyGPUCode.
 *
 * \param constraints The dict specifying the constraints to check; see
 *  tvm::tir::VerifyGPUCode for the supported keys.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifyGPUCode
 */
TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
/*!
 * \brief Statically check TIR code for out-of-bounds array access.
 *
 * This analysis is conservative: it only raises errors if it can prove
 * that an out-of-bounds access occurs. Uncertain cases do not raise
 * errors.
 *
 * \returns The pass.
 */
TVM_DLL Pass OOBChecker();
} // namespace transform
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_