# blob: b27f030e459a297b165988f46d69e1d01c070a10 [file] [log] [blame]
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""The set of passes for Relay.
Exposes an interface for configuring the passes and
scripting them in Python.
"""
from . import _ir_pass
from . import _make
from .expr import Expr
from .ty import Type
def post_order_visit(expr, fvisit):
    """Walk the IR in post-DFS order and apply fvisit to the nodes.

    Each node is guaranteed to be visited only once.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    fvisit : function
        The visitor function to be applied.
    """
    ret = _ir_pass.post_order_visit(expr, fvisit)
    return ret
def infer_type(expr, mod=None):
    """Run type inference on expr under the context of mod.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    checked_expr : tvm.relay.Expr
        The checked expression.
    """
    checked_expr = _ir_pass.infer_type(expr, mod)
    return checked_expr
def backward_fold_scale_axis(expr):
    """Fold axis scaling backward into the weights of conv2d/dense.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression. Its types are expected to be fully
        inferred by infer_type.

    Returns
    -------
    folded_expr : tvm.relay.Expr
        The folded expression after transformation.

    Note
    ----
    It is recommended to call backward_fold_scale_axis before using
    forward_fold_scale_axis, as backward folding targets the common
    conv-bn pattern.
    """
    folded_expr = _ir_pass.backward_fold_scale_axis(expr)
    return folded_expr
def forward_fold_scale_axis(expr):
    """Fold the scaling of an axis into the weights of conv2d/dense.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression. Its types are expected to be fully
        inferred by infer_type.

    Returns
    -------
    folded_expr : tvm.relay.Expr
        The folded expression after transformation.

    Note
    ----
    It is recommended to call backward_fold_scale_axis before using
    forward_fold_scale_axis, as backward folding targets the common
    conv-bn pattern.
    """
    folded_expr = _ir_pass.forward_fold_scale_axis(expr)
    return folded_expr
def well_formed(expr):
    """Check that every Var is bound only once (well formed).

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    well_form : bool
        Whether the input expression is well formed.
    """
    well_form = _ir_pass.well_formed(expr)
    return well_form
def check_kind(t, mod=None):
    """Check that the type is well kinded.

    For example, this means a type cannot have a tensor of tensors, or be
    a tuple type of 2 shapes.

    Parameters
    ----------
    t : tvm.relay.Type
        The type to check.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    well_kinded : bool
        Whether the input type is well kinded.

    Examples
    --------
    .. code:: python

        assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
        assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
    """
    # Without a module, call the one-argument form of the underlying pass.
    if mod is None:
        return _ir_pass.check_kind(t)
    return _ir_pass.check_kind(t, mod)
def free_vars(expr):
    """Collect the free Vars of expr in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    free : List[tvm.relay.Var]
        The list of free variables in post-DFS order.

    Note
    ----
    The post-DFS ordering of the Vars is useful for neural networks:
    usually it means that weights of previous layers are ordered first.
    """
    free = _ir_pass.free_vars(expr)
    return free
def bound_vars(expr):
    """Collect the bound Vars of expr in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    bound : List[tvm.relay.Var]
        The list of bound variables in post-DFS order.
    """
    bound = _ir_pass.bound_vars(expr)
    return bound
def all_vars(expr):
    """Collect every Var of expr in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    variables : List[tvm.relay.Var]
        The list of all variables in post-DFS order.
    """
    variables = _ir_pass.all_vars(expr)
    return variables
def free_type_vars(expr):
    """Collect the free type variables of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr, tvm.relay.Type]
        The input expression/type.

    Returns
    -------
    free : List[tvm.relay.TypeVar]
        The list of free type variables in post-DFS order.
    """
    free = _ir_pass.free_type_vars(expr)
    return free
def bound_type_vars(expr):
    """Collect the bound type variables of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr, tvm.relay.Type]
        The input expression/type.

    Returns
    -------
    bound : List[tvm.relay.TypeVar]
        The list of bound type variables in post-DFS order.
    """
    bound = _ir_pass.bound_type_vars(expr)
    return bound
def all_type_vars(expr):
    """Collect every type variable of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr, tvm.relay.Type]
        The input expression/type.

    Returns
    -------
    type_vars : List[tvm.relay.TypeVar]
        The list of all type variables in post-DFS order.
    """
    type_vars = _ir_pass.all_type_vars(expr)
    return type_vars
def simplify_inference(expr):
    """Simplify the data-flow graph for the inference phase.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    result : tvm.relay.Expr
        An expression which is semantically equal to the input
        expression, but with some simplification.
    """
    result = _ir_pass.simplify_inference(expr)
    return result
def canonicalize_ops(expr):
    """Canonicalize special operators to basic operators.

    This can simplify later analysis (e.g. expanding bias_add to
    expand_dims and broadcast_add).

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    result : tvm.relay.Expr
        An expression without bias_add.
    """
    result = _ir_pass.canonicalize_ops(expr)
    return result
def dead_code_elimination(expr):
    """Remove expressions which do not affect the program result (dead code).

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    result : tvm.relay.Expr
        An expression which is semantically equal to the input
        expression, but with dead code removed.
    """
    result = _ir_pass.dead_code_elimination(expr)
    return result
def alpha_equal(lhs, rhs):
    """Test two Relay exprs for structural equivalence (alpha equivalence).

    Parameters
    ----------
    lhs : tvm.relay.Expr
        One of the input expressions.

    rhs : tvm.relay.Expr
        One of the input expressions.

    Returns
    -------
    result : bool
        True iff lhs is alpha equal to rhs.
    """
    raw_result = _make._alpha_equal(lhs, rhs)
    # Coerce whatever the underlying call returns into a plain Python bool.
    return bool(raw_result)
def graph_equal(lhs, rhs):
    """Test two Relay exprs for data-flow equivalence.

    Unlike alpha-equality, variables are not expected to match between
    lhs and rhs; they are treated as sources and are mapped between each
    other.

    Parameters
    ----------
    lhs : tvm.relay.Expr
        One of the input expressions.

    rhs : tvm.relay.Expr
        One of the input expressions.

    Returns
    -------
    result : bool
        True iff lhs is data-flow equivalent to rhs.
    """
    raw_result = _make._graph_equal(lhs, rhs)
    # Coerce whatever the underlying call returns into a plain Python bool.
    return bool(raw_result)
def structural_hash(value):
    """Hash a Relay expression or type structurally.

    Parameters
    ----------
    value : tvm.relay.Expr or tvm.relay.Type
        The expression or type to hash.

    Returns
    -------
    result : int
        The hash value.

    Raises
    ------
    TypeError
        If value is neither a relay Expr nor a relay Type.
    """
    if isinstance(value, Expr):
        return int(_ir_pass._expr_hash(value))
    if isinstance(value, Type):
        return int(_ir_pass._type_hash(value))
    # The first fragment ends with a space so the two adjacent literals join
    # into a readable sentence ("... expected relay.Expr or relay.Type").
    msg = ("found value of type {0} expected "
           "relay.Expr or relay.Type").format(type(value))
    raise TypeError(msg)
def fold_constant(expr):
    """Fold the constant expressions found in expr.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        The transformed expression.
    """
    transformed_expr = _ir_pass.FoldConstant(expr)
    return transformed_expr
def fuse_ops(expr, opt_level=1):
    """Fuse the operators in expr together.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    opt_level : int
        The level of fuse optimization.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression, containing fused result.
    """
    transformed_expr = _ir_pass.FuseOps(expr, opt_level)
    return transformed_expr
def combine_parallel_conv2d(expr):
    """Fold multiple conv2d into one.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression.
    """
    transformed_expr = _ir_pass.CombineParallelConv2D(expr)
    return transformed_expr
def alter_op_layout(expr):
    """Alternate operator layouts or replace primitive operators.

    Primitive operators may be replaced with other expressions. This pass
    can be used for computing convolution in custom layouts or other
    general weight pre-transformation.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression with alternated layout.
    """
    transformed_expr = _ir_pass.AlterOpLayout(expr)
    return transformed_expr
def rewrite_annotated_ops(expr, fallback_device):
    """Rewrite a program that carries device annotations.

    Annotation operators, e.g. `on_device`, mark which device an
    expression should be scheduled to. This pass helps heterogeneous
    execution where different operators may need to be allocated on
    various devices.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    fallback_device : int
        The fallback device type. It is also used as the default device
        for operators with no annotated device.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression with cross device data copy operators.
    """
    transformed_expr = _ir_pass.RewriteDeviceAnnotation(expr, fallback_device)
    return transformed_expr
def collect_device_info(expr):
    """Collect the device allocation map for the given expression.

    The device ids are propagated from the `device_copy` operators.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.expr, int]
        A dictionary mapping tvm.relay.Expr to device type.
    """
    device_map = _ir_pass.CollectDeviceInfo(expr)
    return device_map
def collect_device_annotation_ops(expr):
    """Collect the device annotation ops for the given expression.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.expr, int]
        A dictionary mapping tvm.relay.Expr to device type where the keys
        are annotation expressions.
    """
    annotation_map = _ir_pass.CollectDeviceAnnotationOps(expr)
    return annotation_map
def to_anf(expr, mod=None):
    """Turn a Graph Normal Form expression into an A Normal Form expression.

    The scope of the root expression is the global scope.
    The scope of any non-root expression is the least common ancestor of
    all its scopes.
    Values are ordered by post-DFS order in each scope.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    expr : tvm.relay.Expr
        The output expression.
    """
    anf_expr = _ir_pass.to_anf(expr, mod)
    return anf_expr
def gradient(expr, mod=None):
    """Transform a function to return its result paired with the input gradient.

    The transformed function returns the original result paired with the
    gradient of the input.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression, which is a Function or a GlobalVar.

    mod : Optional[tvm.relay.Module]
        Optional module, passed through to the underlying pass.

    Returns
    -------
    expr : tvm.relay.Expr
        The output expression.
    """
    grad_expr = _ir_pass.first_order_gradient(expr, mod)
    return grad_expr