| # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name |
| """The expression nodes of Relay.""" |
| from __future__ import absolute_import |
| from numbers import Number as _Number |
| |
| import numpy as _np |
| from .base import RelayNode, register_relay_node |
| from . import _make |
| from . import _expr |
| from . import ty as _ty |
| from .._ffi import base as _base |
| from .. import nd as _nd |
| from .. import convert |
| |
| # will be registered afterwards |
| _op_make = None |
| |
| class Expr(RelayNode): |
| """The base type for all Relay expressions.""" |
| @property |
| def checked_type(self): |
| """Get the checked type of tvm.relay.Expr. |
| |
| Returns |
| ------- |
| checked_type : tvm.relay.Type |
| The checked type. |
| """ |
| ret = self._checked_type_ |
| if ret is None: |
| raise ValueError("The type checker has not populated" |
| " the checked_type for this node") |
| return ret |
| |
| def astype(self, dtype): |
| """Cast the content type of the current data to dtype. |
| |
| Parameters |
| ---------- |
| dtype : str |
| The target data type. |
| |
| Note |
| ---- |
| This function only works for TensorType Exprs. |
| |
| Returns |
| ------- |
| result : tvm.relay.Expr |
| The result expression. |
| """ |
| return _make.cast(self, dtype) |
| |
| def __add__(self, other): |
| if isinstance(other, Expr): |
| return _op_make.add(self, other) |
| elif isinstance(other, _Number): |
| raise TypeError('convert "%s" with `const` first' % str(other)) |
| else: |
| raise TypeError("type %s not supported" % str(type(other))) |
| |
| def __radd__(self, other): |
| return self.__add__(other) |
| |
| def __sub__(self, other): |
| if isinstance(other, Expr): |
| return _op_make.subtract(self, other) |
| elif isinstance(other, _Number): |
| raise TypeError('convert "%s" with `const` first' % str(other)) |
| else: |
| raise TypeError("type %s not supported" % str(type(other))) |
| |
| def __rsub__(self, other): |
| if isinstance(other, _Number): |
| raise TypeError('convert "%s" with `const` first' % str(other)) |
| else: |
| raise TypeError("type %s not supported" % str(type(other))) |
| |
| def __mul__(self, other): |
| if isinstance(other, Expr): |
| return _op_make.multiply(self, other) |
| elif isinstance(other, _Number): |
| raise TypeError('convert "%s" with `const` first' % str(other)) |
| else: |
| raise TypeError("type %s not supported" % str(type(other))) |
| |
| def __rmul__(self, other): |
| return self.__mul__(other) |
| |
| def __div__(self, other): |
| if isinstance(other, Expr): |
| return _op_make.divide(self, other) |
| elif isinstance(other, _Number): |
| raise TypeError('convert "%s" with `const` first' % str(other)) |
| else: |
| raise TypeError("type %s not supported" % str(type(other))) |
| |
| def __rdiv__(self, other): |
| if isinstance(other, _Number): |
| raise TypeError('convert "%s" with `const` first' % str(other)) |
| else: |
| raise TypeError("type %s not supported" % str(type(other))) |
| |
| def __truediv__(self, other): |
| return self.__div__(other) |
| |
| def __rtruediv__(self, other): |
| return self.__rdiv__(other) |
| |
| |
| @register_relay_node |
| class Constant(Expr): |
| """A constant expression in Relay. |
| |
| Parameters |
| ---------- |
| data : tvm.nd.NDArray |
| The data content of the constant expression. |
| """ |
| def __init__(self, data): |
| self.__init_handle_by_constructor__(_make.Constant, data) |
| |
| |
| @register_relay_node |
| class Tuple(Expr): |
| """Tuple expression that groups several fields together. |
| |
| Parameters |
| ---------- |
| fields : List[tvm.relay.Expr] |
| The fields in the tuple. |
| """ |
| def __init__(self, fields): |
| self.__init_handle_by_constructor__(_make.Tuple, fields) |
| |
| def __getitem__(self, index): |
| if index >= len(self): |
| raise IndexError("Tuple index out of range") |
| return self.fields[index] |
| |
| def __len__(self): |
| return len(self.fields) |
| |
| def astype(self, _): |
| raise TypeError("astype cannot be used on tuple") |
| |
| |
| @register_relay_node |
| class Var(Expr): |
| """A local variable in Relay. |
| |
| Local variable can be used to declare input |
| arguments to a function, or intermediate variables. |
| |
| Parameters |
| ---------- |
| name_hint: str |
| The name of the variable. |
| This name only acts as a hint, and is not used |
| for equality. |
| |
| type_annotation: tvm.relay.Type, optional |
| The type annotation on the variable. |
| """ |
| def __init__(self, name_hint, type_annotation=None): |
| self.__init_handle_by_constructor__( |
| _make.Var, name_hint, type_annotation) |
| |
| @property |
| def name_hint(self): |
| """Get name hint of the current var.""" |
| name = self.vid.name_hint |
| return name |
| |
| |
| @register_relay_node |
| class GlobalVar(Expr): |
| """A global variable in Tvm.Relay. |
| |
| GlobalVar is used to refer to the global functions |
| stored in the module. |
| |
| Parameters |
| ---------- |
| name_hint: str |
| The name of the variable. |
| """ |
| def __init__(self, name_hint): |
| self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) |
| |
| def __call__(self, *args): |
| """Invoke the gobal function. |
| |
| Parameters |
| ---------- |
| args: List[relay.Expr] |
| Arguments. |
| """ |
| return Call(self, args, None, None) |
| |
| |
| @register_relay_node |
| class Function(Expr): |
| """A function declaration expression. |
| |
| Parameters |
| ---------- |
| params: List[tvm.relay.Var] |
| List of input parameters to the function. |
| |
| body: tvm.relay.Expr |
| The body of the function. |
| |
| ret_type: Optional[tvm.relay.Type] |
| The return type annotation of the function. |
| |
| type_params: Optional[List[tvm.relay.TypeParam]] |
| The additional type parameters, this is only |
| used in advanced usecase of template functions. |
| """ |
| def __init__(self, |
| params, |
| body, |
| ret_type=None, |
| type_params=None, |
| attrs=None): |
| if type_params is None: |
| type_params = convert([]) |
| |
| self.__init_handle_by_constructor__( |
| _make.Function, params, body, ret_type, type_params, attrs) |
| |
| def __call__(self, *args): |
| """Invoke the global function. |
| |
| Parameters |
| ---------- |
| args: List[relay.Expr] |
| Arguments. |
| """ |
| return Call(self, args, None, None) |
| |
| |
| @register_relay_node |
| class Call(Expr): |
| """Function call node in Relay. |
| |
| Call node corresponds the operator application node |
| in computational graph terminology. |
| |
| Parameters |
| ---------- |
| op: tvm.relay.Op or any tvm.relay.Expr with function type. |
| The operation to be called. |
| |
| args: List[tvm.relay.Expr] |
| The arguments to the call. |
| |
| attrs: Optional[tvm.Attrs] |
| Attributes to the call, can be None |
| |
| type_args: Optional[List[tvm.relay.Type]] |
| The additional type arguments, this is only |
| used in advanced usecase of template functions. |
| """ |
| def __init__(self, op, args, attrs=None, type_args=None): |
| if not type_args: |
| type_args = [] |
| self.__init_handle_by_constructor__( |
| _make.Call, op, args, attrs, type_args) |
| |
| |
| @register_relay_node |
| class Let(Expr): |
| """Let variable binding expression. |
| |
| Parameters |
| ---------- |
| variable: tvm.relay.Var |
| The local variable to be bound. |
| |
| value: tvm.relay.Expr |
| The value to be bound. |
| |
| body: tvm.relay.Expr |
| The body of the let binding. |
| """ |
| def __init__(self, variable, value, body): |
| self.__init_handle_by_constructor__( |
| _make.Let, variable, value, body) |
| |
| |
| @register_relay_node |
| class If(Expr): |
| """A conditional expression in Relay. |
| |
| Parameters |
| ---------- |
| cond: tvm.relay.Expr |
| The condition. |
| |
| true_branch: tvm.relay.Expr |
| The expression evaluated when condition is true. |
| |
| false_branch: tvm.relay.Expr |
| The expression evaluated when condition is false. |
| """ |
| def __init__(self, cond, true_branch, false_branch): |
| self.__init_handle_by_constructor__( |
| _make.If, cond, true_branch, false_branch) |
| |
| |
| @register_relay_node |
| class TupleGetItem(Expr): |
| """Get index-th item from a tuple. |
| |
| Parameters |
| ---------- |
| tuple_value: tvm.relay.Expr |
| The input tuple expression. |
| |
| index: int |
| The index. |
| """ |
| def __init__(self, tuple_value, index): |
| self.__init_handle_by_constructor__( |
| _make.TupleGetItem, tuple_value, index) |
| |
| |
| class TempExpr(Expr): |
| """Baseclass of all TempExpr. |
| |
| TempExprs are pass specific expression that can be |
| useful to define intermediate result in the |
| rewriting pass such as layout or type transformation. |
| """ |
| def realize(self): |
| """Convert the expression to a normal(non-temp) Expr. |
| |
| Returns |
| ------- |
| The corresponding normal expression. |
| """ |
| return _expr.TempExprRealize(self) |
| |
| |
| class TupleWrapper(object): |
| """TupleWrapper. |
| |
| This class is a Python wrapper for a Relay tuple of known size. |
| It allows for accessing the fields of the Relay tuple as though |
| it were a Python tuple. |
| |
| Parameters |
| ---------- |
| tuple_value: tvm.relay.Expr |
| The input tuple |
| |
| size: int |
| The size of the tuple. |
| """ |
| def __init__(self, tuple_value, size): |
| self.tuple_value = tuple_value |
| self.size = size |
| |
| def astuple(self): |
| """Returns the underlying Relay tuple if this wrapper is passed |
| as an argument to an FFI function.""" |
| return self.tuple_value |
| |
| def astext(self): |
| """Get the text format of the tuple expression. |
| |
| Returns |
| ------- |
| text : str |
| The text format of the tuple expression. |
| """ |
| return self.tuple_value.astext() |
| |
| def __getitem__(self, index): |
| if index >= len(self): |
| raise IndexError("Tuple index out of range") |
| return TupleGetItem(self.tuple_value, index) |
| |
| def __len__(self): |
| return self.size |
| |
| def __repr__(self): |
| return ("TupleWrapper(" + self.tuple_value.__repr__() + |
| ", " + str(self.size) + ")") |
| |
| def astype(self, _): |
| raise TypeError("astype cannot be used on tuple") |
| |
| |
| def var(name_hint, |
| type_annotation=None, |
| shape=None, |
| dtype="float32"): |
| """Create a new tvm.relay.Var. |
| |
| This is a simple wrapper function that allows specify |
| shape and dtype directly. |
| |
| Parameters |
| ---------- |
| name_hint: str |
| The name of the variable. |
| This name only acts as a hint, and is not used |
| for equality. |
| |
| type_annotation: Optional[tvm.relay.Type, str] |
| The type annotation on the variable. |
| When type_annotation is a str, we will create a scalar variable. |
| |
| shape: Optional[List[tvm.Expr]] |
| The shape of the tensor type. |
| |
| dtype: str, optional |
| The data type of the tensor. |
| |
| Examples |
| -------- |
| .. code-block:: python |
| |
| # The following 4 lines are equivalent to each other |
| x = tvm.relay.Var("x", tvm.relay.TensorType([1, 2])) |
| x = tvm.relay.var("x", tvm.relay.TensorType([1, 2])) |
| x = tvm.relay.var("x", shape=[1, 2]) |
| x = tvm.relay.var("x", shape=[1, 2], dtype="float32") |
| |
| # The following 2 lines are equivalent to each other. |
| y = tvm.relay.var("x", "float32") |
| y = tvm.relay.var("x", shape=(), dtype="float32") |
| """ |
| if type_annotation is not None and shape is not None: |
| raise ValueError("Can only specify either type_annotation or shape.") |
| if shape is not None: |
| type_annotation = _ty.TensorType(shape, dtype) |
| elif isinstance(type_annotation, str): |
| type_annotation = _ty.TensorType((), type_annotation) |
| return Var(name_hint, type_annotation) |
| |
| |
| def const(value, dtype=None): |
| """Create a constant value. |
| |
| Parameters |
| ---------- |
| value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] |
| The constant value. |
| |
| dtype: str, optional |
| The data type of the value. |
| |
| Note |
| ---- |
| When dtype is None, we use the following rule: |
| |
| - int maps to "int32" |
| - float maps to "float32" |
| - bool maps to "bool" |
| - other using the same default rule as numpy. |
| """ |
| if isinstance(value, (_base.numeric_types, (bool, list))): |
| value = _np.array(value, dtype=dtype) |
| if not dtype: |
| # when dtype is None: int maps to "int32", float maps to "float32" |
| map_dtype = { |
| _np.dtype('int64'): _np.int32, |
| _np.dtype('float64'): _np.float32 |
| }.get(value.dtype, None) |
| if map_dtype: |
| value = value.astype(map_dtype) |
| if isinstance(value, (_np.ndarray, _np.generic)): |
| value = _nd.array(value) |
| |
| if not isinstance(value, _nd.NDArray): |
| raise ValueError("value has to be scalar or NDArray") |
| return Constant(value) |
| |
| |
| def bind(expr, binds): |
| """Bind an free variables in expr or function arguments. |
| |
| We can bind parameters expr if it is a function. |
| |
| Parameters |
| ---------- |
| expr : tvm.relay.Expr |
| The input expression. |
| |
| binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]] |
| The specific bindings. |
| |
| Returns |
| ------- |
| result : tvm.relay.Expr |
| The expression or function after binding. |
| """ |
| return _expr.Bind(expr, binds) |