blob: 17995bfa785009a99ac9cc9073ffb3a8874855d3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-argument
"""Common pass infrastructure across IR variants."""
import types
import inspect
import functools
import tvm._ffi
import tvm.runtime
from . import _ffi_transform_api
@tvm._ffi.register_object("transform.PassInfo")
class PassInfo(tvm.runtime.Object):
    """Meta data attached to a pass.

    An instance is the container of the information needed to run an
    optimization or analysis. New members can be added to this class when
    more meta data is needed.

    Parameters
    ----------
    opt_level : int
        The optimization level of this pass.

    name : str
        The pass name.

    required : List[str]
        The list of passes that are required by a certain pass.
    """

    def __init__(self, opt_level, name, required=None):
        # Every field is forwarded unchanged to the FFI constructor.
        make_pass_info = _ffi_transform_api.PassInfo
        self.__init_handle_by_constructor__(
            make_pass_info,
            opt_level,
            name,
            required,
        )
def _normalize_pass_context_arg(value, arg_name):
    """Convert an optional collection argument of ``PassContext`` to a list.

    Parameters
    ----------
    value : Optional[Iterable]
        The user supplied collection. Any falsy value (``None`` or an empty
        container) results in an empty list.

    arg_name : str
        The name of the argument; only used to build the error message.

    Returns
    -------
    result : list
        A new list holding the elements of ``value``.

    Raises
    ------
    TypeError
        If ``value`` is a ``str``/``bytes`` object or cannot be iterated.
    """
    if not value:
        return []
    message = arg_name + " is expected to be the type of list/tuple/set."
    # A string is iterable, so ``list(value)`` would silently split it into
    # single characters instead of reporting the misuse.
    if isinstance(value, (str, bytes)):
        raise TypeError(message)
    try:
        return list(value)
    except TypeError as err:
        # Non-iterable input: report it with the documented message.
        raise TypeError(message) from err


@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    """The basis where a Relay optimization/analysis runs on.

    Each pass context contains a number of auxiliary information that is used
    to help an optimization pass. Such information includes the error reporter
    to record the errors of during the optimization, etc.

    Parameters
    ----------
    opt_level : Optional[int]
        The optimization level of this pass. Defaults to 2.

    required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
        The list of passes that are required by a certain pass.

    disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
        The list of passes that are disabled.

    instruments : Optional[Sequence[PassInstrument]]
        The list of pass instrument implementations.

    config : Optional[Dict[str, Object]]
        Additional configurations for specific passes.

    Raises
    ------
    TypeError
        If ``required_pass``, ``disabled_pass`` or ``instruments`` is a
        string or is not iterable.
    """

    def __init__(
        self,
        opt_level=2,
        required_pass=None,
        disabled_pass=None,
        instruments=None,
        config=None,
    ):
        # Validate *before* converting: once a value went through ``list()``
        # an ``isinstance(..., (list, tuple))`` check can no longer fail.
        required = _normalize_pass_context_arg(required_pass, "required_pass")
        disabled = _normalize_pass_context_arg(disabled_pass, "disabled_pass")
        instruments = _normalize_pass_context_arg(instruments, "instruments")

        # An empty config is passed to the backend as None.
        config = config or None
        self.__init_handle_by_constructor__(
            _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config
        )

    def __enter__(self):
        _ffi_transform_api.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace):
        # Returns None, so an exception raised inside the ``with`` body
        # is not suppressed.
        _ffi_transform_api.ExitPassContext(self)

    def override_instruments(self, instruments):
        """Override instruments within this PassContext.

        If there are existing instruments, their ``exit_pass_ctx`` callbacks are called.
        Then switching to new instruments and calling new ``enter_pass_ctx`` callbacks.

        Parameters
        ----------
        instruments : Sequence[PassInstrument]
            The list of pass instrument implementations.
        """
        _ffi_transform_api.OverrideInstruments(self, instruments)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _ffi_transform_api.GetCurrentPassContext()

    @staticmethod
    def list_configs():
        """List all registered `PassContext` configuration names and metadata.

        Returns
        -------
        configs : Dict[str, Dict[str, str]]
        """
        return _ffi_transform_api.ListConfigs()
@tvm._ffi.register_object("transform.Pass")
class Pass(tvm.runtime.Object):
    """The base class of all passes.

    Every method here is a simple wrapper around an implementation in the
    backend; they exist so users can conveniently interact with the base
    class.
    """

    @property
    def info(self):
        """Get the pass meta."""
        pass_info = _ffi_transform_api.Info(self)
        return pass_info

    def __call__(self, mod):
        """Execute the pass.

        Note that for sequential pass, the dependency among different passes
        will be resolved in the backend.

        Parameters
        ----------
        mod : tvm.IRModule
            The module that a certain optimization is performed on.

        Returns
        -------
        mod : tvm.IRModule
            The updated module after applying this pass.
        """
        updated_mod = _ffi_transform_api.RunPass(self, mod)
        return updated_mod
@tvm._ffi.register_object("transform.ModulePass")
class ModulePass(Pass):
    """A pass that works on tvm.IRModule.

    Users don't need to interact with this class directly. A module pass
    should instead be created through `module_pass`, whose API is flexible
    enough to create a module pass in different manners. All members of a
    module pass can be accessed from the base class.

    The same rule applies to FunctionPass as well.
    """
@tvm._ffi.register_object("transform.Sequential")
class Sequential(Pass):
    """A pass that works on a sequence of pass objects.

    Multiple passes can be executed sequentially using this class.

    Note that users can also provide a series of passes that they don't want to
    apply when running a sequential pass. Pass dependency will be resolved in
    the backend as well.

    Parameters
    ----------
    passes : Optional[List[Pass]]
        A sequence of passes candidate for optimization.

    opt_level : Optional[int]
        The optimization level of this sequential pass.
        The opt_level of a default sequential pass is set to 0.
        Note that some of the passes within the Sequential may still not be
        executed if their opt_level is higher than the provided opt_level.

    name : Optional[str]
        The name of the sequential pass.

    required : Optional[List[str]]
        The list of passes that the sequential pass is dependent on.
    """

    def __init__(self, passes=None, opt_level=0, name="sequential", required=None):
        # Any falsy input (None, empty container) becomes an empty list.
        pass_list = passes or []
        if not isinstance(pass_list, (list, tuple)):
            raise TypeError("passes must be a list of Pass objects.")

        required_list = required or []
        if not isinstance(required_list, (list, tuple)):
            raise TypeError("Required is expected to be the type of list/tuple.")

        self.__init_handle_by_constructor__(
            _ffi_transform_api.Sequential,
            pass_list,
            opt_level,
            name,
            required_list,
        )
def _wrap_class_module_pass(pass_cls, pass_info):
    """Wrap a python class as a module pass.

    Parameters
    ----------
    pass_cls : type
        The user class. Its instances must provide a
        ``transform_module(mod, ctx)`` method, which is used as the pass
        function.

    pass_info : PassInfo
        The meta data handed to the backend when the pass is created.

    Returns
    -------
    PyModulePass : type
        A ModulePass subclass that carries the name, docstring and module of
        ``pass_cls``. Constructing it builds a ``pass_cls`` instance from the
        given arguments and creates a module pass around that instance.
    """

    class PyModulePass(ModulePass):
        """Internal wrapper class to create a class instance."""

        def __init__(self, *args, **kwargs):
            # initialize handle in case pass_cls creation failed.
            self.handle = None
            inst = pass_cls(*args, **kwargs)

            # it is important not to capture self to
            # avoid a cyclic dependency
            def _pass_func(mod, ctx):
                return inst.transform_module(mod, ctx)

            self.__init_handle_by_constructor__(
                _ffi_transform_api.MakeModulePass, _pass_func, pass_info
            )
            self._inst = inst

        def __getattr__(self, name):
            # ``_inst`` is only assigned at the very end of ``__init__``.
            # Without this guard, looking it up on a partially constructed
            # object would re-enter ``__getattr__`` until RecursionError.
            if name == "_inst":
                raise AttributeError(name)
            # fall back to instance attribute if there is not any
            return self._inst.__getattribute__(name)

    functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
    PyModulePass.__name__ = pass_cls.__name__
    PyModulePass.__doc__ = pass_cls.__doc__
    PyModulePass.__module__ = pass_cls.__module__
    return PyModulePass
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
    """Decorate a module pass.

    This function returns a callback when pass_func is provided.
    Otherwise, it serves a decorator function.

    pass_func can also be a class type with a method transform_module.
    This function will create a decorated ModulePass using transform_module
    as the pass function.

    Parameters
    ----------
    pass_func : Optional[Callable[(Module, PassContext) ->Module]]
        The transformation function or class.

    opt_level : int
        The optimization level of this module pass.

    name : Optional[str]
        The name of the module pass. The name could be empty. In this case, the
        name of the optimization function will be used as the pass name.

    required : Optional[List[str]]
        The list of passes that the module pass is dependent on.

    Returns
    -------
    create_module_pass : Union[Callable, ModulePass]
        A decorator will be returned if pass_func is not provided,
        otherwise return the decorated result.
        The returned decorator has two behaviors depending on the input:
        A new ModulePass will be returned when we decorate a pass function.
        A new ModulePass class will be returned when we decorate a class type.

    Raises
    ------
    ValueError
        If ``opt_level`` is not provided.

    TypeError
        If ``required`` is not a list/tuple, or if the decorated object is
        neither a class nor a python function.

    Examples
    --------
    The following code block decorates a module pass class.

    .. code-block:: python

        @relay.transform.module_pass
        class CustomPipeline:
            def __init__(self, enable_fold):
                self.enable_fold = enable_fold
                self.cse = relay.transform.EliminateCommonSubexpr()
                self.const_fold = relay.transform.FoldConstant()

            def transform_module(self, mod, ctx):
                mod = self.cse(mod, ctx)
                if self.enable_fold:
                    mod = self.const_fold(mod, ctx)
                return mod

        # create an instance of customized pipeline
        pipeline = CustomPipeline(enable_fold=False)
        assert isinstance(pipeline, transform.ModulePass)
        # run the pipeline.
        output_module = pipeline(input_module)

    The following code creates a module pass by decorating
    a user defined transform function.

    .. code-block:: python

        @relay.transform.module_pass(opt_level=2)
        def transform(mod, ctx):
            tp = relay.TensorType((10,), "float32")
            x = relay.var("x", tp)
            gv = relay.GlobalVar("var")
            func = relay.Function([x], relay.abs(x))
            new_mod = tvm.IRModule({gv: func})
            new_mod.update(mod)
            return new_mod

        module_pass = transform
        assert isinstance(module_pass, transform.ModulePass)
        assert module_pass.info.opt_level == 2

        # Given a module m, the optimization could be invoked as the following:
        updated_mod = module_pass(m)
        # Now a function abs should be added to the module m.
    """
    if opt_level is None:
        raise ValueError("Please provide opt_level for the module pass.")

    required = required if required else []
    if not isinstance(required, (list, tuple)):
        raise TypeError("Required is expected to be the type of list/tuple.")

    def create_module_pass(pass_arg):
        """Internal function that creates a module pass"""
        is_class = inspect.isclass(pass_arg)
        # Validate first: reading ``pass_arg.__name__`` on an arbitrary object
        # would otherwise raise AttributeError instead of the TypeError below.
        if not is_class and not isinstance(pass_arg, types.FunctionType):
            raise TypeError("pass_func must be a callable for Module pass")

        fname = name if name else pass_arg.__name__
        info = PassInfo(opt_level, fname, required)
        if is_class:
            return _wrap_class_module_pass(pass_arg, info)
        return _ffi_transform_api.MakeModulePass(pass_arg, info)

    # Compare against None so that a falsy but invalid argument is reported
    # instead of silently switching to decorator mode.
    if pass_func is not None:
        return create_module_pass(pass_func)
    return create_module_pass
def PrintIR(header="", show_meta_data=False):
    """A special trace pass that prints the header and IR.

    Parameters
    ----------
    header : str
        The header to be displayed along with the dump.

    show_meta_data : bool
        A boolean flag to indicate if meta data should be printed.

    Returns
    -------
    The pass
    """
    print_pass = _ffi_transform_api.PrintIR(header, show_meta_data)
    return print_pass