blob: d8f36e478669479b14e61154f645ddb8f831a687 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relax/transform.h
* \brief Relax-specific transformation passes.
*/
#ifndef TVM_RELAX_TRANSFORM_H_
#define TVM_RELAX_TRANSFORM_H_
#include <tvm/ir/transform.h>
#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/index_map.h>
namespace tvm {
namespace relax {
namespace transform {
// Short aliases for the generic pass-infrastructure types (tvm::transform) and the
// Relax IR node types that appear in the pass-creation signatures of this header.
using Pass = tvm::transform::Pass;
using PassInfo = tvm::transform::PassInfo;
using PassContext = tvm::transform::PassContext;
using Function = tvm::relax::Function;
using DataflowBlock = tvm::relax::DataflowBlock;
/*!
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization. It has the signature
* `Function(Function, IRModule, PassContext)`: it receives a Relax function together with an
* IRModule and the pass context, and returns the resulting function.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
* \param traceable Whether the function pass is traceable (defaults to false).
*
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required, bool traceable = false);
/*!
* \brief Create a dataflowblock pass.
*
* \param pass_func The packed function that contains the optimization. It has the signature
* `DataflowBlock(DataflowBlock, IRModule, PassContext)`: it receives a dataflow block together
* with an IRModule and the pass context, and returns the resulting dataflow block.
* \param opt_level The optimization level of the dataflowblock pass.
* \param name The name of the dataflowblock pass.
* \param required The list of the passes that the dataflowblock pass is dependent on.
* \param traceable Whether the dataflowblock pass is traceable (defaults to false).
*
* \return The created dataflowblock pass.
*/
TVM_DLL Pass CreateDataflowBlockPass(
const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required, bool traceable = false);
/*!
* \brief Perform lambda lifting: lift nested functions out into global functions.
*
* \return The Pass.
*/
TVM_DLL Pass LambdaLift();
/*!
* \brief Transform all dataflow structures into their non-dataflow versions.
*
* \return The Pass.
*/
TVM_DLL Pass ToNonDataflow();
/*!
* \brief Activate force_pure on all pure functions in the module
* and unwrap all pure override ops into their normal versions.
*
* This effectively means that there will be no more purity tracking,
* which is useful for low-level code generation.
*
* \return The Pass.
*
* \note Should be used after ToNonDataflow().
*/
TVM_DLL Pass RemovePurityChecking();
/*!
* \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
*
* \return The Pass.
*/
TVM_DLL Pass CallTIRRewrite();
/*!
* \brief Convert all reshape-like call_tir whose corresponding binding
* vars are DataflowVars to relax.reshape operator calls. The relax.reshape
* calls will be lowered to an external builtin function call in a subsequent
* pass, where the external builtin function does a CreateView operation
* at runtime instead of doing a real data copy.
* Here "reshape-like" includes reshape, expand_dims, flatten, etc.
*
* \return The Pass.
* \note The pass is applied at the first stage of Relax VM build, before
* rewriting call_tir, as this pass requires dataflow information.
*/
TVM_DLL Pass RewriteDataflowReshape();
/*!
* \brief The static memory planning pass on BindingBlock level.
* The pass reuses allocated memory on a best-effort basis, in order to
* reduce the total amount of allocated memory.
*
* The pass "supports" dynamic shapes by way of TIR variable upper bound
* annotations. We can optionally annotate the attribute "tir_var_upper_bound"
* on Relax functions. The attribute value is a dict from strings to integers,
* mapping the names of TIR variables to their upper bound values.
* Note: For clarity, the annotated upper bound attribute only applies to TIR vars
* in the function signature.
*
* For example, we can annotate a Relax function with
* `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
* It means that the variable named "n" in the function signature has the
* upper bound 1024, and 1024 will be used as its value during memory planning.
*
* \return The pass.
*/
TVM_DLL Pass StaticPlanBlockMemory();
/*!
* \brief Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.
*
* \return The Pass.
*/
TVM_DLL Pass AttachGlobalSymbol();
/*!
* \brief Transform Relax IR to normal form: transform the AST to A-normal form, and fill the
* checked_type_ and shape_ of expressions.
*
* \return The Pass.
*/
TVM_DLL Pass Normalize();
/*!
* \brief Possibly rename the GlobalVars in an IRModule to ensure these properties:
* 1. (Invariant) First ensure every public function has the same name as its "global_symbol"
* attribute;
* 2. To ensure 1., we may need to rename private functions with conflicting names;
* 3. Finally, the name of every GlobalVar is unique in the IRModule.
*
* \return The Pass.
*/
TVM_DLL Pass NormalizeGlobalVar();
/*!
* \brief Simplify a Relax module by folding var bindings and match shape nodes,
* as well as tuple indices.
* Best used alongside constant folding and eliminating unused bindings.
*
* \note If a dataflow var is used only in a binding to the dataflow block
* output var (i.e., a non-dataflow var), this pass will also remove the dataflow var
* and replace the output var's binding with the dataflow var's direct definition.
*
* \return The Pass.
*/
TVM_DLL Pass CanonicalizeBindings();
/*!
* \brief Eliminate common subexpressions within functions.
*
* \param call_only If true, only call nodes are eligible for elimination (defaults to false).
*
* \return The pass that eliminates common subexpressions.
*
* \note For nested functions, this pass performs CSE *within* those functions.
*/
TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
/*!
* \brief Bind the parameters of a function in the module to constant tensors.
*
* \param func_name The name of the function whose parameters are bound.
* \param params The parameters to bind, mapped to their constant values.
* NOTE(review): keys are typed as ObjectRef; presumably a parameter Var or its
* name, as in BindSymbolicVars — confirm against the implementation.
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
/*!
* \brief Bind symbolic vars to constant shape values.
*
* \param binding_map The dictionary of symbolic variables and their
* constant shape values. Dictionary keys may be either a
* `tir.Var` or a string name of the variable. If the variables
* are referred to by name, the name must uniquely identify a
* symbolic variable in each function where it is used.
*
* \param func_name The name of the function in which to bind shape
* values. If NullOpt (the default), all functions in the module will be
* updated.
*
* \return The Pass.
*/
TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
Optional<String> func_name = NullOpt);
/*!
* \brief Fold constant expressions within dataflow blocks.
*
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*
* \return The Pass.
*/
TVM_DLL Pass FoldConstant();
/*!
* \brief Legalize high-level operator calls in Relax functions to call_tir
* with corresponding low-level TIR PrimFuncs.
*
* For each high-level operator, we register the way of legalizing it as a
* function, which takes a context BlockBuilder and the Call being legalized
* as input, and returns the legalized call. Here the input BlockBuilder is
* mainly used for adding the PrimFunc created by call_te into the context
* IRModule.
*
* The legalization function for each operator is registered as an attribute (with
* attribute key `FLegalize`) of the operator.
*
* For customizability, the user can pass their own legalization functions through an optional
* customized map, whose keys are operator names and whose values are legalization functions.
* A customized function overrides the default one for the same operator.
*
* \param cmap The optional customized operator legalization function map. The customized
* functions will override the default ones.
* \param enable_warning Whether to print warnings for TIR functions not
* showing up in the database (defaults to false).
* \return The Pass.
*/
TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning = false);
/*!
* \brief Propagate virtual device information.
*
* \return The Pass.
*/
TVM_DLL Pass RealizeVDevice();
/*!
* \brief Lift transformation of the parameters of a function.
*
* When some inputs of the function are marked as 'parameters' (the model weights), this pass
* identifies the transformation of the parameters and lifts them to a separate function called
* `transform_params`. `transform_params` takes a tuple of the original parameters as input and
* returns a tuple of the transformed parameters. The original function will be rewritten to accept
* a tuple of transformed parameters as input.
*
* Users are expected to invoke the `transform_params` function at runtime and pass the transformed
* parameters to the original function as input.
*
* \param shared_transform Indicates how the parameter transformation function will be produced.
* - `False` (default): A separate parameter transformation function will be produced for each
* function with the `"num_input"` attribute.
*
* - `True`: A single parameter transformation function will be produced, containing the
* preprocessing steps common across all functions with the `"num_input"` attribute.
*
* - List[str]: A single parameter transformation function will be produced, containing the
* preprocessing steps common across each function whose name is in the list. Passing a list of
* all functions with the `"num_input"` attribute or an empty list is equivalent to passing
* `True`.
*
* \return The Pass.
*/
TVM_DLL Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform = Bool(false));
/*!
* \brief Update a virtual device.
* \param new_vdevice The new virtual device.
* \param index The device index, indicating the device on which the update will be performed.
* \return The Pass.
*/
TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
/*! \brief Expand tuple arguments of internal functions.
*
* \return The Pass.
*/
TVM_DLL Pass ExpandTupleArguments();
/*! \brief Remove unused parameters from internal functions.
*
* \return The Pass.
*/
TVM_DLL Pass RemoveUnusedParameters();
/*! \brief Remove unused outputs from internal functions.
*
* \return The Pass.
*/
TVM_DLL Pass RemoveUnusedOutputs();
/*!
* \brief Annotate the Op Pattern Kind of TIR functions, which is used in FuseOps.
* \note It is an auto-detect pass for "unscheduled prim_funcs"; the op_pattern will be
* "opaque" if we can't detect it. Users can manually annotate the attr `op_pattern`
* on a prim_func.
* \return The Pass.
*/
TVM_DLL Pass AnnotateTIROpPattern();
/*!
* \brief This pass groups bindings in a dataflow block of Relax functions and generates a new
* grouped Relax function for each group, according to the fusion algorithm described in the pass
* implementation. By grouping bindings into new Relax functions, we substitute the bindings in the
* function being manipulated with calls to the new grouped functions.
*
* A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.
* \param fuse_opt_level The level of fuse optimization.
* -1 (the default) indicates that the level will be inferred from the pass context.
* \return The Pass.
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
/*!
* \brief The pattern object used as the input of FuseOpsByPattern. For bindings to be
* fused, they need to be matched by `pattern`, and the `check` function (if any) needs to
* return true.
*/
class FusionPatternNode : public Object {
public:
/*!
* \brief The name of the pattern. It becomes the value of the kComposite attribute
* of a fused function after successful matching.
*/
String name;
/*!
* \brief The dataflow pattern that will be used to match expressions in the DataflowBlock.
* All the call nodes covered by the pattern will be extracted into the fused function.
*/
DFPattern pattern;
/*!
* \brief The map which is used to extract important expressions from the pattern match
* result. All DFPatterns in this map should be part of the `pattern`.
*/
Map<String, DFPattern> annotation_patterns;
/*!
* \brief The function to determine whether the match result is accepted. This can be
* NullOpt if a check function is not necessary for this pattern.
*
* It should have signature
* bool(const PatternCheckContext& context)
*/
Optional<PackedFunc> check;
/*!
* \brief The function to get attributes for the fused function.
*
* It should have signature
* Map<String, ObjectRef>(const Map<String, Expr>& context)
*/
Optional<PackedFunc> attrs_getter;
/*! \brief Visit every field of this node, by name, with the given AttrVisitor. */
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("pattern", &pattern);
v->Visit("annotation_patterns", &annotation_patterns);
v->Visit("check", &check);
v->Visit("attrs_getter", &attrs_getter);
}
static constexpr const char* _type_key = "relax.transform.FusionPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object);
};
/*!
* \brief Reference type for FusionPatternNode (declared with
* TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS).
*/
class FusionPattern : public ObjectRef {
public:
/*!
* \brief Construct a FusionPattern with every field specified.
* \param name The pattern name (becomes the kComposite attribute of the fused function).
* \param pattern The dataflow pattern to match.
* \param annotation_patterns Sub-patterns used to extract important expressions from a match.
* \param check Optional function deciding whether a match is accepted.
* \param attrs_getter Optional function producing attributes for the fused function.
*/
FusionPattern(String name, DFPattern pattern, Map<String, DFPattern> annotation_patterns,
Optional<PackedFunc> check, Optional<PackedFunc> attrs_getter);
/*!
* \brief Construct a FusionPattern from only a name and a pattern: annotation_patterns is
* left empty, and both check and attrs_getter are NullOpt.
*/
FusionPattern(String name, DFPattern pattern)
: FusionPattern(name, pattern, {}, NullOpt, NullOpt) {}
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode);
};
/*!
* \brief The input of FusionPattern::check.
*/
class PatternCheckContextNode : public Object {
public:
/*!
* \brief The expression that's matched with the FusionPattern::pattern.
*/
Expr matched_expr;
/*!
* \brief A map which contains all expressions matched by the sub-patterns in
* FusionPattern::annotation_patterns.
*/
Map<String, Expr> annotated_expr;
/*!
* \brief Map from variable to its value. It contains variables from the bindings that
* are being fused by FuseOpsByPattern.
*/
Map<Var, Expr> matched_bindings;
/*!
* \brief A map from each variable definition to its uses (stored as an array of Vars).
* It has all variables used in the function.
*/
Map<Var, Array<Var>> var_usages;
/*!
* \brief Map from value to its bound variable. It doesn't contain variables defined after
* the matched expression.
*/
Map<Expr, Var> value_to_bound_var;
/*! \brief Visit every field of this node, by name, with the given AttrVisitor. */
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("matched_expr", &matched_expr);
v->Visit("annotated_expr", &annotated_expr);
v->Visit("matched_bindings", &matched_bindings);
v->Visit("var_usages", &var_usages);
v->Visit("value_to_bound_var", &value_to_bound_var);
}
static constexpr const char* _type_key = "relax.transform.PatternCheckContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object);
};
/*!
* \brief Reference type for PatternCheckContextNode (declared with
* TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS).
*/
class PatternCheckContext : public ObjectRef {
public:
/*!
* \brief Construct a PatternCheckContext with every field specified.
* \param matched_expr The expression matched by FusionPattern::pattern.
* \param annotated_expr Expressions matched by FusionPattern::annotation_patterns.
* \param matched_bindings Map from variable to value for the bindings being fused.
* \param var_usages Map from each variable definition to its uses.
* \param value_to_bound_var Map from value to its bound variable.
*/
PatternCheckContext(Expr matched_expr, Map<String, Expr> annotated_expr,
Map<Var, Expr> matched_bindings, Map<Var, Array<Var>> var_usages,
Map<Expr, Var> value_to_bound_var);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef,
PatternCheckContextNode);
};
/*!
* \brief Reverse-mode automatic differentiation.
*
* This pass differentiates one function in the IRModule. Currently the input function must have
* only one dataflow block.
*
* For a given function specified by `func_name`, it generates a new function with the name
* `func_name + "_adjoint"`. The new function computes the gradient of the **differentiation
* target** with respect to the arguments specified by `require_grads` of the original function.
*
* If the function has only one return value, the return value will be used as the target. If the
* function has more than one return value, the target will be the target_index-th
* return value. The target must be a scalar (0-dim tensor).
*
* \param func_name The name of the specified function.
* \param require_grads The relax variables whose adjoints are needed. They must be parameters of
* the given function and should not contain duplicates. If not specified (NullOpt), adjoints of
* all parameters will be computed.
* \param target_index If the specified function has more than one return value, the index
* of the return value to use as the target. If not specified, the first return value (index 0)
* will be the target.
* \return The Pass.
*
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*/
TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
int target_index = 0);
/*!
* \brief Apply pattern matching to each function in the given module, and group matched
* expressions into a new function. The end result is similar to FuseOps, but fusion is driven
* completely by the provided patterns.
*
* \param patterns The patterns to detect. The order of the patterns determines the order
* of priority in which they are matched. Higher-priority patterns should come earlier in the list.
* \param bind_constants Whether or not to keep bound constants of the grouped function
* (defaults to true).
* \param annotate_codegen If true, wrap each created composite function with another function,
* whose body consists only of a call to the composite function, and annotate the outer function
* with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set as the prefix of the
* corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu".
* This must be true if the created composite functions are intended to be offloaded to
* an external backend without using the MergeCompositeFunctions pass.
* \param entry_function_names The names of functions that should be considered as entry points. If
* not specified, all externally exposed functions will be considered as entry points.
* \return The Pass.
*
* \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
bool annotate_codegen = false,
const tvm::Array<String>& entry_function_names = {});
/*!
* \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
* function. The new function will be annotated with kCodegen and GlobalSymbol attributes,
* and it is intended to be offloaded to an external backend.
*
* \return The Pass.
*/
TVM_DLL Pass MergeCompositeFunctions();
/*!
* \brief Fuse Relax sub-functions into larger TIR functions if possible.
* This pass works together with FuseOps to perform operator fusion.
* \return The Pass.
*/
TVM_DLL Pass FuseTIR();
/*!
* \brief Run codegen.
* \param target_options Optional map from target name to that target's compilation options.
* \param entry_functions The names of the entry functions.
* \return The Pass.
*/
TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_options,
Array<runtime::String> entry_functions);
/*!
* \brief Decompose composite operators during inference. For example, the result of batch norm
* (a triple) will be simplified. Operators like Attention, Erf, etc. can also be simplified into
* several operators.
*
* \param func_name The name of the specified function. If not specified, the pass will run on
* all functions.
* \return The Pass.
*/
TVM_DLL Pass DecomposeOpsForInference(Optional<String> func_name);
/*!
* \brief Decompose composite operators during training. For example, the result of batch norm
* (a triple) will be simplified. Operators like Attention, Erf, etc. can also be simplified into
* several operators.
*
* \param func_name The name of the specified function. If not specified, the pass will run on
* all functions.
* \return The Pass.
*/
TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
/*!
* \brief Returns a pass which replaces PrimFuncs that have a matching kOperatorName attribute in
* \p op_impl_map with replacement PrimFuncs that could possibly have different layouts on i/o
* buffers. The layout transformations on i/o buffers are given by \p op_buffer_transforms. The
* pass inserts the layout transformations at the call sites of the PrimFuncs being replaced to
* transform i/o buffers into the expected layout.
*
* \param op_impl_map Map from kOperatorName attr (e.g., relax.conv2d) to replacement PrimFunc.
* \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
* PrimFunc i/o buffers.
* \param axis_separators Map from kOperatorName attr to axis_separators of each buffer_transforms.
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
const Map<String, Array<Array<IntImm>>>& axis_separators);
/*!
* \brief Layout conversion pass.
* \param desired_layouts The desired layouts for some operators.
* \return The Pass.
* \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
/*!
* \brief A pass that converts consecutive dataflow operations
* inside binding blocks into dataflow blocks.
* \param min_size The minimum number of consecutive dataflow bindings
* required for the pass to create a new dataflow block (defaults to 2).
* \return The Pass.
*/
TVM_DLL Pass ConvertToDataflow(int min_size = 2);
/*!
* \brief Dead code elimination.
*
* Currently it removes:
* 1. Unused local VarBindings
* (those where the bound var is unused and no impure operation is used).
* 2. Unused Relax functions in the module.
* We detect the call chain from the entry functions, and remove all unused functions.
*
* Any binding blocks that are left empty will be removed by the normalizer.
*
* \param entry_functions Names of functions that should be considered
* as entry points, in addition to any externally exposed functions.
*
* \return The Pass.
*
* \sa RemoveAllUnused
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions = {});
/*!
* \brief Pass that changes calls to operators that can be done in-place
* (generally, these are elementwise operations) in dataflow blocks into in-place implementations.
* Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place
* PrimFunc implementations of those operators (which are based on the legalizations of those
* operators).
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
* \return The pass.
*/
TVM_DLL Pass DataflowUseInplaceCalls();
/*!
* \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
* only, and will automatically cast fp32 to fp16 for certain ops.
* \param out_dtype The output data type of gemm/conv, which is the data type of the accumulator.
* \param fp16_input_names The names of function parameters whose dtype should become fp16. The
* function signature would change accordingly. Defaults to NullOpt.
* \return The Pass.
*
* \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
Optional<Array<String>> fp16_input_names = NullOpt);
/*!
* \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies
* the regions that can be executed with CUDA graph and lifts them into new functions for runtime
* graph capturing.
* \return The Pass.
*/
TVM_DLL Pass RewriteCUDAGraph();
/*!
* \brief The pass is designed for few-shot tuning of static-shape PrimFuncs. It examines all the
* blocks within the PrimFunc and conducts loop fusion, splitting, and other transformations based
* on MetaSchedule schedule rules, but directly samples from the search space instead of using the
* tuning algorithm. Users can specify the number of valid counts to try and whether to use a runner
* for benchmarking.
* \param valid_count The number of valid counts to try.
* \param benchmark Whether to use a runner for benchmarking.
* \return The Pass.
*/
TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);
} // namespace transform
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_TRANSFORM_H_