// blob: cdea8e8e3c235d4fda9303af94ae01fc5a416ca6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/transform.h
* \brief Relay specific transformation passes.
*/
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/ir/transform.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/compilation_config.h>
#include <tvm/target/target.h>
#include <tvm/target/virtual_device.h>
#include <functional>
#include <string>
namespace tvm {
namespace relay {
namespace transform {
// Convenience aliases that re-export the generic pass infrastructure types from
// tvm::transform, so Relay pass code can refer to them as relay::transform::*.
using Pass = tvm::transform::Pass;
using PassNode = tvm::transform::PassNode;
using PassInfo = tvm::transform::PassInfo;
using PassInfoNode = tvm::transform::PassInfoNode;
using PassContext = tvm::transform::PassContext;
using PassContextNode = tvm::transform::PassContextNode;
using Sequential = tvm::transform::Sequential;
/*!
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization. It is invoked with a
* Function, an IRModule and the current PassContext, and returns a Function.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);
/*!
* \brief Remove let-bound expressions which do not affect the program result.
*
* This pass will remove let bindings which are not referenced. If \p inline_once is true,
* let bindings which are only referenced once will also be inlined.
*
* For example, this pass should turn `let a = 1; 2` into `2`,
* as the value of the expression does not depend on `a`.
*
* As another example, `let a = 1; a` will be optimized into `1` if \p inline_once is true.
*
* If \p ignore_purity is false, possibly side-effecting expressions (such as memory allocation,
* random number generation, reading/writing references, or calls to primitive or external
* functions) are never elided or inlined. This is sound, but \p ignore_purity can be set to
* true to suppress this check.
*
* The analysis is fairly conservative: for example, it assumes that all local functions
* may be called more than once, that any functions passed as arguments have side effects,
* and so on.
*
* \param inline_once Whether or not to inline bindings used exactly once.
* \param ignore_purity Whether to skip the side-effect (purity) check described above.
*
* \return The pass.
*/
TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false);
/*!
* \brief Convert all expressions of TensorType into GradCell,
* an algebraic data type defined in gradient.rly.
*
* This delays or decreases memory usage: calls to ones, ones_like, zeros and
* zeros_like do not immediately instantiate a tensor in memory; the tensor is only
* instantiated if it is needed. The pass also defines the + and * operations
* between GradCell values, which can improve performance when working with
* zero-filled or one-filled tensors, as is the case in reverse-mode AD.
*
* \return The pass.
*/
TVM_DLL Pass LazyGradientInit();
/*!
* \brief Fold constant expressions.
*
* For backward-compatibility reasons, QNN primitives are skipped (not folded) by default.
* Some transformation passes, such as FakeQuantizationToInteger, require QNN primitives to be
* kept for constant subgraphs, and uncontrolled constant folding of QNN primitives may break
* the applicability of FakeQuantizationToInteger. We suggest using the FoldConstant pass with
* the non-default value fold_qnn=true only after all other QNN-sensitive passes have been
* applied.
*
* \param fold_qnn Whether to fold constants for QNN operations.
*
* \return The pass.
*/
TVM_DLL Pass FoldConstant(bool fold_qnn = false);
/*!
* \brief Split functions with a huge number of arguments into smaller pieces.
*
* \param max_function_args The maximum number of arguments a function may take; functions
* with more arguments are split. NOTE(review): confirm how non-positive values are treated.
*
* \return The pass.
*/
TVM_DLL Pass SplitArgs(int max_function_args);
/*!
* \brief Fuse operations in an expression into separate functions.
*
* \param fuse_opt_level Optimization level. If it is -1, the level is inferred from the
* pass context.
*
* \return The pass.
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
/*!
* \brief The inverse operation of FuseOps. It transforms a fused program returned by
* FuseOps back into the program as it was before FuseOps, i.e.
* `x == DefuseOps(FuseOps(x))`.
*
* \return The pass.
*/
TVM_DLL Pass DefuseOps();
/*!
* \brief Rewrite the annotated program.
*
* \param fallback_device The fallback device, used as the default device for
* operators without an annotation.
*
* \return The pass.
*/
TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
/*!
* \brief Turn an expression into Basic Block Normal Form.
*
* We define a block as a group of expressions implied by the scope structure.
*
* Each graph node can only belong to a single block.
*
* Any value that is used in multiple blocks has to be referred to by a Var
* defined in a block whose scope is the least common ancestor of the blocks
* in which the value is used.
*
* \return The pass.
*/
TVM_DLL Pass ToBasicBlockNormalForm();
/*!
* \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It turns an expression in graph form (with implicit sharing) into an
* expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non-root expression is the least common ancestor of all its scopes.
*
* Values are ordered by post-DFS order in each scope.
*
* \return The pass.
*/
TVM_DLL Pass ToANormalForm();
/*!
* \brief ToANormalForm applied directly to an expression, which may be an incomplete graph.
*
* \param expr The graph.
*
* \return The transformed expression.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);
/*!
* \brief Turn an expression into continuation-passing style (CPS).
*
* CPS means that, instead of returning its result directly, every function
* takes an extra function argument (called the continuation) and passes its
* result to that continuation.
*
* Thus, every function call has to be passed an extra argument that represents
* the rest of the computation (hence the name "continuation").
*
* Similarly, all other computations are wrapped so that they call the continuation as well.
*
* \return The pass.
*/
TVM_DLL Pass ToCPS();
/*!
* \brief Remove let bindings and share values directly via pointers instead.
*
* It removes all let bindings and turns every variable bound by a let
* into a direct pointer reference, producing graph normal form.
*
* \return The pass.
*/
TVM_DLL Pass ToGraphNormalForm();
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It does as much computation at compile time as possible.
* This has two benefits: it removes runtime overhead, and it enables further
* optimization (typically fusion). As a side effect, code size can explode.
*
* \return The pass.
*/
TVM_DLL Pass PartialEval();
/*!
* \brief Simplify certain operators during inference. For example, the result
* of a batch norm that is indexed at tuple index 0 is unpacked into a
* number of simpler operators.
*
* \return The pass.
*/
TVM_DLL Pass SimplifyInference();
/*!
* \brief Replace non-linear activation functions with their fast but approximate counterparts.
*
* \return The pass.
*/
TVM_DLL Pass FastMath();
/*!
* \brief Find dynamic ops and make them static.
*
* Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it
* replaces them with static ops and re-runs type inference and constant folding. The pass
* repeats itself until the graph stops changing or too many iterations have been run.
*
* \return The pass.
*/
TVM_DLL Pass DynamicToStatic();
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambiguous
* type information filled in, as well as its checked type field
* populated with the result type.
*
* \return The pass.
*/
TVM_DLL Pass InferType();
/*!
* \brief Infer the type of an expression, reusing existing type information.
*
* Type information is filled in for the given node only. This local
* version can use the existing type information populated throughout
* the expression and assumes that this information is correct. It
* also avoids examining large parts of the graph, on the assumption that
* type information is already filled in properly, which makes it much faster
* when type inference is called iteratively.
*
* \param expr The expression whose type is inferred.
*
* \return The type of the expression.
*/
TVM_DLL Type InferTypeLocal(const Expr& expr);
/*!
* \brief Search for and eliminate common subexpressions. For example, if two
* expressions evaluate to an identical value, a single variable is created
* and both expressions are replaced by this variable.
*
* \param fskip A callback that allows certain expressions to be skipped.
*
* \return The pass.
*/
TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
/*!
* \brief Combine parallel 2d convolutions into a single convolution if the
* number of branches of this conv2d operator is not less than
* `min_num_branches`.
*
* \param min_num_branches The minimum number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
/*!
* \brief Combine parallel dense ops into a single batch_matmul if the
* number of branches of this dense operator is not less than
* `min_num_branches`.
*
* \param min_num_branches The minimum number of branches.
* \param to_batch_matmul Whether to combine parallel dense ops into a batch matmul.
* If set to false, the dense ops are combined into a single dense op.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
/*!
* \brief Combine parallel batch_matmul ops into a single batch_matmul
* if the number of branches of this batch_matmul operator is not less than
* `min_num_branches`.
*
* \param min_num_branches The minimum number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
/*!
* \brief Fold axis scaling backward into the weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass BackwardFoldScaleAxis();
/*!
* \brief Fold axis scaling forward into the weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass ForwardFoldScaleAxis();
/*!
* \brief A sequential pass that executes the ForwardFoldScaleAxis and
* BackwardFoldScaleAxis passes.
*
* \return The pass.
*/
TVM_DLL Pass FoldScaleAxis();
/*!
* \brief Canonicalize some operators into simpler operators. For example,
* bias_add can be canonicalized into expand_dims and broadcast_add.
*
* \return The pass.
*/
TVM_DLL Pass CanonicalizeOps();
/*!
* \brief Alter the layouts of operators, or replace primitive operators
* with other expressions.
*
* \return The pass.
*/
TVM_DLL Pass AlterOpLayout();
/*!
* \brief Rewrite layouts according to the tile structure created by the auto-scheduler.
*
* \return The pass.
*/
TVM_DLL Pass AutoSchedulerLayoutRewrite();
/*!
* \brief Rewrite layouts according to the tile structure created by meta-schedule.
*
* \return The pass.
*/
TVM_DLL Pass MetaScheduleLayoutRewrite();
/*!
* \brief Given a destination layout, this pass transforms the expression so that the input
* data layout of most ops is changed to the destination layout. Ideally, there are only two
* layout transforms: one at the start and one at the end.
*
* This pass is not part of relay.build and is expected to be called between the
* framework-to-Relay parser and the relay.build call. This is very helpful for hardware
* backends that support/prefer only one type of data layout.
*
* RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
*
* This pass reuses most of the AlterOpLayout and InferCorrectLayout infrastructure. For now,
* new layouts can be defined for conv2d ops; most other operators try to adapt to their input
* layout using the InferCorrectLayout infrastructure.
*
* \param desired_layouts A mapping from op_name to an array of desired layouts, one per input.
* For example: Map("nn.conv2d", Array("NHWC", "OHWI")),
* specifies the desired layout for the data and then the kernel of nn.conv2d.
*
* \return The pass.
*/
TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
/*!
* \brief Legalize an expression by replacing it with another expression.
*
* \param legalize_map_attr_name The Op attribute name that corresponds to the legalize rule
* function. This parameter can be used to collect and isolate similar kinds of legalize
* transformations. For example, transformations that only apply to Dialects can be isolated
* under an FTVMDialectLegalize attribute name. This pass calls only those transformations that
* have been registered under the supplied legalize_map_attr_name.
*
* \return The pass.
*/
TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
/*!
* \brief Canonicalize cast expressions to make operator fusion more efficient.
*
* \return The pass.
*/
TVM_DLL Pass CanonicalizeCast();
/*!
* \brief Add abstraction over a constructor or a global variable bound to a function.
*
* For example: `square` is transformed to
* `fn (%x: int32) -> int32 { square(%x) }`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param expand_constructor Whether to expand constructors.
* \param expand_global_var Whether to expand global variables.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
/*!
* \brief Partition a Relay program into regions that can be executed on
* different backends.
*
* \return The pass.
*/
TVM_DLL Pass PartitionGraph();
/*!
* \brief Inline the global functions marked as `inline` in a given Relay
* IRModule.
*
* \return The pass.
*/
TVM_DLL Pass Inline();
/*!
* \brief Remove the unused functions from the Relay IRModule.
*
* \param entry_functions The entry functions from which the search for used functions
* starts.
*
* \return The pass.
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
/*!
* \brief Simplify the Relay expression.
*
* \return The pass.
*/
TVM_DLL Pass SimplifyExpr();
/*!
* \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
*
* This pass looks for inline, let-bound or global functions which have a "Compiler" attribute.
* If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the
* 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole.
*
* If, in addition, the \p config has a Target with a matching TargetKind, that Target is set
* as the 'current' target before the custom pass is executed. In this way it is possible
* for custom passes to pick up target options which may guide how they transform the IRModule.
* (Those targets are referred to as 'extern codegen targets' elsewhere.)
*
* A typical custom pass will:
* - Find calls to functions whose "Compiler" attribute matches the compiler name.
* - Lower those functions to TIR PrimFuncs.
* - Bind those functions into the IRModule under the functions' "global_symbol" attribute.
* - Replace all calls to those functions with 'call_lowered' to the matching global.
* Care should be taken to handle multiple calls to the same function.
* See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass.
*
* It is also possible (despite the pass and attribute names!) for the custom pass to proceed
* directly to a runtime::Module, which can be attached to the output IRModule's "external_mods"
* attribute (taking care not to clobber any existing modules). In this case the flow is as
* above, except:
* - The runtime::Module must contain a binding for each compiled function under its
* "global_symbol" (i.e. runtime::Module::ImplementsFunction should return true).
* - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same
* "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body
* should be the original function body. In this way we always have a TVM definition matching
* every global function name.
*
* There are many existing runtime::Modules, ranging from source to object to dynamic libraries
* to entirely custom implementations. Some of those may require additional compilation using
* 'export_library' on the final build artifact.
*
* The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern
* utility passes can be used by custom passes to take care of some of the boilerplate.
*
* TODO(mbs): Rename PreLoweringTargetHooks?
*
* \param config All available targets.
*
* \return The pass.
*/
TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config);
/*!
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param cpu_virtual_device The VirtualDevice for computations and data which must reside on
* a CPU, such as shapes and shape functions.
*
* \return The pass.
*/
TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
/*!
* \brief A pass for manifesting variable lifetimes by inserting kill operations when variables
* become dead. This pass should be run after ManifestAlloc, and must not be run more than once.
*
* \return The pass.
*/
TVM_DLL Pass ManifestLifetimes();
/*!
* \brief Use existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on
* which every Relay sub-expression should run and where its result should be stored. The
* result of that analysis is captured using new "on_device" and "device_copy" CallNodes.
*
* See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
* for help recovering the device for an arbitrary sub-expression in downstream transformations.
*
* \param config Describes the targets and the default \p VirtualDevice for all primitive
* operators and host sub-expressions.
*
* \return The pass.
*/
TVM_DLL Pass PlanDevices(CompilationConfig config);
/*!
* \brief Flatten atrous convolutions. An atrous convolution corresponds to the sequence of
* operations "space_to_batch_nd"->"conv2d"->"batch_to_space_nd"; this pass converts such a
* sequence into a subgraph containing a convolution with a modified "dilation" and recalculated
* "padding" parameters.
*
* \return The pass.
*/
TVM_DLL Pass FlattenAtrousConv();
/*!
* \brief Annotate the minimum required memory of each primitive function callsite by analyzing
* the liveness of the input/output tensors at each function callsite and calculating the total
* amount of memory these tensors require. This is added as a "used_memory" annotation to the
* function in question, as a list of the number of bytes for each callsite. In addition, the
* containing function is annotated with an "io_used_memory" annotation, which refers to the
* total memory required for the IO tensors.
*
* Note: This pass does not support dynamic shapes. It is the user's responsibility to ensure
* this pass isn't applied where dynamic shapes may be input.
*
* \return The pass.
*/
TVM_DLL Pass AnnotateUsedMemory();
/*!
* \brief Capture the post-dfs index and dominator post-dfs index of (most) expression nodes in
* their span, in the form "index:<post-dfs index>:<dominator post-dfs index>". This is useful
* for debugging since a) it helps identify pretty-printed sub-expressions within the overall
* model and b) the indexes are heavily used by Collage for its compact representation of
* sub-graphs.
*
* Note that Op and Constructor nodes are not changed even though they are assigned a
* post-dfs index.
*
* \return The pass.
*/
TVM_DLL Pass CapturePostDfsIndexInSpans();
/*!
* \brief Run a device-dependent memory scope analysis, collect a mapping from expressions to
* their desired memory_scope, and annotate the expressions with VirtualDevices carrying the
* required memory_scope.
*
* \param config The compilation configuration.
*
* \return The pass.
*/
TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config);
/*!
* \brief Remove non-fused reshapes after lowering the graph.
*
* InferType() cannot be invoked after calling this pass, because it removes reshapes from the
* call graph. Many targets only need buffer addresses, irrespective of their shapes, which makes
* reshapes symbolic once the graph has been lowered. Removing reshapes results in smaller code
* size and fewer buffer allocations, and opens up opportunities for operator fusion in the
* target backend. Consequently, it improves inference performance.
*
* \return The pass.
*/
TVM_DLL Pass RemoveStandaloneReshapes();
} // namespace transform
/*!
* \brief Bind the free variables of a Relay expression. This is a helper
* function usually called by other pass functions to help optimizations.
* If any free variables are introduced into a function, they are added
* to the function parameters.
* Additionally, this may change the order of parameters if you map a variable
* to a variable.
*
* \param expr The input expression.
* \param binds The variable-to-expression map that will be used for the
* binding.
*
* \return The updated expression.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
/*!
* \brief Substitute variables (including function parameters) in a function with new variables.
* This is a helper function usually called by other pass functions to help optimizations.
* Expects all values in the bind map to be Vars.
*
* \param func The input function.
* \param binds The variable-to-expression map that will be used for the
* binding.
*
* \return The updated function.
*/
TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
/*!
* \brief Apply rewrite rules to rewrite the expression in post-DFS order. This
* function is used as a helper function to rewrite an expression in a pass.
*
* \param expr The expression.
* \param rewrite_map_attr_name The Op attribute name that corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide a context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr is consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Apply rewrite rules to rewrite the expression in post-DFS order. This
* function is used as a helper function to rewrite an expression in a pass.
*
* \param expr The expression.
* \param rewrite_func The rewrite function that will be applied to all operators.
* \param fcontext Additional callback to provide a context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr is consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
/*!
* \brief Rewrite the annotated program.
*
* \param expr The expression.
* \param fallback_device The fallback device, used as the default device for
* operators without an annotation.
*
* \return The updated program.
*/
TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
/*!
* \brief Turn a function into continuation-passing style (CPS).
*
* CPS means that, instead of returning its result directly, every function
* takes an extra function argument (called the continuation) and passes its
* result to that continuation.
*
* Thus, every function call has to be passed an extra argument that represents
* the rest of the computation (hence the name "continuation").
*
* Similarly, all other computations are wrapped so that they call the continuation as well.
*
* \param f The function.
* \param mod The module.
*
* \return The converted Function.
*/
TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
/*!
* \brief Remove the continuation argument of a CPS function.
*
* Note that this only transforms the type back into un-CPS form
* when there is no higher-order input/output.
*
* \param f The function.
*
* \return The converted Function.
*/
TVM_DLL Function UnCPS(const Function& f);
/*!
* \brief Deduplicate the bound variables and type variables in the expression.
*
* \param e The expression.
*
* \return The deduplicated expression.
*/
TVM_DLL Expr DeDup(const Expr& e);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORM_H_