blob: 96dde5acb4dfeb7e7943c604416831ff9d732dee [file] [log] [blame]
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import RelayNode, register_relay_node
from . import _make
class Type(RelayNode):
    """Common base class for every Relay type node.

    ``==`` and ``!=`` are overridden to compare types by alpha
    equivalence; use :meth:`same_as` to get the comparison implemented
    by the parent class instead.

    NOTE(review): this class defines ``__eq__`` without ``__hash__``, which
    in Python 3 makes instances unhashable unless something outside this
    block restores ``__hash__`` -- confirm that this is intended.
    """

    def __eq__(self, other):
        """Check structural (alpha) equivalence with another Relay type.

        Parameters
        ----------
        other : object
            The value to compare against; it is forwarded unchanged to
            ``_make._type_alpha_equal``.

        Returns
        -------
        result : bool
            The outcome of ``_make._type_alpha_equal`` coerced to ``bool``.
        """
        alpha_equal = _make._type_alpha_equal(self, other)
        return bool(alpha_equal)

    def __ne__(self, other):
        """Return the logical negation of :meth:`__eq__`."""
        is_equal = self.__eq__(other)
        return not is_equal

    def same_as(self, other):
        """Compare two Relay types by referential equality.

        This bypasses the alpha-equivalence override above and delegates
        to the ``__eq__`` of the parent class.
        """
        parent_eq = super().__eq__
        return parent_eq(other)
@register_relay_node
class TensorType(Type):
    """A concrete tensor type in Relay.

    This is the type assigned to tensors with a known dtype and shape,
    for example a tensor of `float32` with shape `(5, 5)`.

    Parameters
    ----------
    shape : List[tvm.Expr]
        The shape of the tensor.

    dtype : Optional[str]
        The content data type.
        Defaults to "float32".

    Returns
    -------
    tensor_type : tvm.relay.TensorType
        The tensor type.
    """

    def __init__(self, shape, dtype="float32"):
        # Node construction is delegated to the registered constructor.
        ctor_args = (shape, dtype)
        self.__init_handle_by_constructor__(_make.TensorType, *ctor_args)

    @property
    def concrete_shape(self):
        """Return the shape of this type as a tuple of Python ints.

        Every entry of ``self.shape`` is converted with ``int``.

        Returns
        -------
        shape : Tuple[int]
            The concrete shape of the type.

        Raises
        ------
        TypeError : If the shape is symbolic (presumably raised by the
            ``int`` conversion of a non-constant dimension -- TODO confirm).
        """
        dims = []
        for dim in self.shape:
            dims.append(int(dim))
        return tuple(dims)
class Kind(IntEnum):
    """The kind of a type parameter.

    A kind represents a variable shape, base type, type, or dimension,
    and controls what a type parameter is allowed to be instantiated
    with. For example, parameters of kind BaseType can only be
    `float32`, `int32`, and so on.
    """
    # The integer values are handed to the `_make` constructors as the
    # `kind` argument by TypeVar and IncompleteType in this module, so
    # they presumably have to match the backend's numbering -- do not
    # renumber without checking.
    Type = 0      # default kind used by TypeVar and IncompleteType
    ShapeVar = 1
    BaseType = 2  # e.g. `float32`, `int32`
    Shape = 3
@register_relay_node
class TypeVar(Type):
    """A type variable used for generic types in Relay.

    See tvm/relay/type.h for more details.

    A type variable is a placeholder for a type that will be filled in
    later on, which lets the user write functions that are generic over
    types.
    """

    def __init__(self, var, kind=Kind.Type):
        """Construct a TypeVar.

        Parameters
        ----------
        var : tvm.expr.Var
            The tvm.Var which backs the type parameter.

        kind : Optional[Kind]
            The kind of the type parameter.
            Defaults to Kind.Type.

        Returns
        -------
        type_var : tvm.relay.TypeVar
            The type variable.
        """
        # Node construction is delegated to the registered constructor.
        ctor_args = (var, kind)
        self.__init_handle_by_constructor__(_make.TypeVar, *ctor_args)
@register_relay_node
class TypeConstraint(Type):
    """Abstract base class for type constraints.

    It adds no members of its own; TypeRelation in this module derives
    from it.
    """
@register_relay_node
class TupleType(Type):
    """A tuple type in Relay.

    See tvm/relay/type.h for more details.

    The type lists the type of each field in the tuple.
    """

    def __init__(self, fields):
        """Construct a tuple type.

        Parameters
        ----------
        fields : List[tvm.relay.Type]
            The types of the fields in the tuple.

        Returns
        -------
        tuple_type : tvm.relay.TupleType
            The tuple type.
        """
        # Node construction is delegated to the registered constructor.
        make_tuple_type = _make.TupleType
        self.__init_handle_by_constructor__(make_tuple_type, fields)
@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. A function type
    consists of a list of type parameters, which enable the definition
    of generic functions, a set of type constraints, which we omit for
    the time being, a sequence of argument types, and a return type.

    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`

    Parameters
    ----------
    arg_types : List[tvm.relay.Type]
        The argument types.

    ret_type : tvm.relay.Type
        The return type.

    type_params : Optional[List[tvm.relay.TypeVar]]
        The type parameters. ``None`` is treated as an empty list.

    type_constraints : Optional[List[tvm.relay.TypeConstraint]]
        The type constraints. ``None`` is treated as an empty list.
    """

    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        # A fresh list is created per call, so no default is ever shared.
        type_params = [] if type_params is None else type_params
        type_constraints = (
            [] if type_constraints is None else type_constraints)
        ctor_args = (arg_types, ret_type, type_params, type_constraints)
        self.__init_handle_by_constructor__(_make.FuncType, *ctor_args)
@register_relay_node
class IncompleteType(Type):
    """An incomplete type.

    Parameters
    ----------
    kind : Optional[Kind]
        The kind of the incomplete type.
        Defaults to Kind.Type.
    """

    def __init__(self, kind=Kind.Type):
        # Node construction is delegated to the registered constructor.
        make_incomplete_type = _make.IncompleteType
        self.__init_handle_by_constructor__(make_incomplete_type, kind)
@register_relay_node
class TypeRelation(TypeConstraint):
    """A type relation in Relay.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : [tvm.relay.Type]
        List of types passed to the func.

    num_inputs : int
        Number of input arguments in args;
        this acts as a hint for type inference.

    attrs : Attrs
        The attributes attached to the relation information.

    Returns
    -------
    type_relation : tvm.relay.TypeRelation
        The type relation.
    """

    def __init__(self, func, args, num_inputs, attrs):
        # Node construction is delegated to the registered constructor.
        ctor_args = (func, args, num_inputs, attrs)
        self.__init_handle_by_constructor__(_make.TypeRelation, *ctor_args)
def scalar_type(dtype):
    """Create a scalar type.

    A scalar is represented as a tensor type with an empty shape, i.e.
    this function returns ``TensorType((), dtype)``.

    Parameters
    ----------
    dtype : str
        The content data type.

    Returns
    -------
    s_type : tvm.relay.TensorType
        The result type.
    """
    scalar_shape = ()
    return TensorType(scalar_shape, dtype)