blob: c4d5df7ac06c6658bf87ff06e2c88c4470450f38 [file] [log] [blame]
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from . import _make
class Type(NodeBase):
"""The base type for all Relay types."""
def __eq__(self, other):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._type_alpha_eq(self, other))
def __ne__(self, other):
return not self.__eq__(other)
def same_as(self, other):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details.
This is the type assigned to tensor's with a known dype and shape. For
example a tensor of `float32` and `(5, 5)`.
"""
def __init__(self, shape, dtype):
"""Construct a tensor type.
Parameters
----------
shape: list of tvm.Expr
dtype: str
Returns
-------
tensor_type: The TensorType
"""
self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)
class Kind(IntEnum):
"""The kind of a type parameter, represents a variable shape,
base type, type, or dimension.
This controls what a type parameter is allowed to be instantiated
with. For example one's of kind BaseType can only be `float32`, `int32`,
and so on.
"""
ShapeVar = 0
Shape = 1
BaseType = 2
Type = 3
@register_relay_node
class TypeParam(Type):
"""A type parameter used for generic types in Relay,
see tvm/relay/type.h for more details.
A type parameter represents a type placeholder which will
be filled in later on. This allows the user to write
functions which are generic over types.
"""
def __init__(self, var, kind):
"""Construct a TypeParam.
Parameters
----------
var: tvm.expr.Var
The tvm.Var which backs the type parameter.
kind: Kind
The kind of the type parameter.
Returns
-------
type_param: TypeParam
The type parameter.
"""
self.__init_handle_by_constructor__(_make.TypeParam, var, kind)
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
pass
@register_relay_node
class TupleType(Type):
"""A tuple type in Relay, see tvm/relay/type.h for more details.
Lists the type of each field in the tuple.
"""
def __init__(self, fields):
"""Constructs a tuple type
Parameters
----------
fields: list of tvm.Type
Returns
-------
tuple_type: the tuple type
"""
self.__init_handle_by_constructor__(_make.TupleType, fields)
@register_relay_node
class FuncType(Type):
"""A function type in Relay, see tvm/relay/type.h for more details.
This is the type assigned to functions in Relay. They consist of
a list of type parameters which enable the definition of generic
functions, a set of type constraints which we omit for the time
being, a sequence of argument types, and a return type.
We informally write them as:
`forall (type_params), (arg_types) -> ret_type where type_constraints`
"""
def __init__(self,
arg_types,
ret_type,
type_params,
type_constraints,
):
"""Construct a function type.
Parameters
----------
arg_types: list of Type
ret_type: Type
type_params: list of TypeParam
type_constraints: list of TypeConstraint
Returns
-------
func_type: FuncType
The function type.
"""
self.__init_handle_by_constructor__(
_make.FuncType, arg_types, ret_type, type_params, type_constraints)
@register_relay_node
class IncompleteType(Type):
"""An incomplete type."""
def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
"""Type relation in relay.
Parameters
----------
func : EnvFunc
User defined relation function.
args : list of types
List of types to the func.
num_inputs: int
Number of input arguments in args,
this act as a hint for type inference.
attrs : Attrs
The attribute attached to the relation information
"""
def __init__(self, func, args, num_inputs, attrs):
self.__init_handle_by_constructor__(_make.TypeRelation,
func, args, num_inputs, attrs)