blob: f3bcf2c0a1d9f7e0a504e32eb0c04ccd7d214b18 [file] [log] [blame]
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/type.h
* \brief Relay typed AST nodes.
*/
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
#include <string>
#include "base.h"
#include "../attrs.h"
namespace tvm {
namespace relay {
/*!
 * \brief Base container node of the Relay type hierarchy.
 *
 * The type containers declared in this header (tensor, type variable,
 * incomplete type, constraint, function and tuple) derive from this class.
 *
 * NOTE(review): the class derives from RelayNode while the node-info macro
 * names Node as the parent — confirm this is intended.
 */
class TypeNode : public RelayNode {
 public:
  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.Type";
  TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
};
/*!
 * \brief Type is the base reference class of the Relay type hierarchy.
 *
 * Relay's type system is built on the following two key concepts:
 *
 * - TensorType: the type of certain tensor values in an expression.
 * - FuncType: the type of a function.
 *
 * There are also advanced types that support generic (polymorphic) typing;
 * these can be ignored when first reading the code base.
 */
class Type : public NodeRef {
 public:
  /*! \brief Default constructor. */
  Type() {}
  /*!
   * \brief Construct from a node pointer.
   * \param n The node pointer handed to the NodeRef base class.
   */
  explicit Type(NodePtr<tvm::Node> n) : NodeRef(n) {}

  /*! \brief The container node type associated with this reference. */
  using ContainerType = TypeNode;
};
/*!
 * \brief Common base container of all tensor types.
 *
 * A reference to this container can hold a TensorType or a
 * GenericTensorType.
 */
class BaseTensorTypeNode : public TypeNode {
 public:
  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.BaseTensorType";
  TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);
/*!
* \brief This is the most commonly used type in relay.
* TensorType has a fixed dimension and data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorTypeNode The container class of TensorType.
*/
class TensorType;
/*! \brief TensorType container node */
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by IndexExpr(tvm::Expr).
*/
Array<IndexExpr> shape;
/*! \brief The content data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL IndexExpr Size() const;
TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
/*! \brief Construct an scalar containing elements of dtype. */
TVM_DLL static TensorType Scalar(DataType dtype);
static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
};
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
*
* For example, in the following pseudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
* \code
*
* template<i32 n>
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeVarNode The actual container class of TypeVar
*/
class TypeVar;

/*! \brief Container node holding the fields of a TypeVar. */
class TypeVarNode : public TypeNode {
 public:
  /*! \brief Possible kinds of TypeVar. */
  enum Kind : int {
    kType = 0,
    /*! \brief Template variable in a shape expression. */
    kShapeVar = 1,
    kBaseType = 2,
    kShape = 3
  };
  /*!
   * \brief The variable itself is only meaningful when
   *  kind is kShapeVar; otherwise, only its name is used.
   */
  tvm::Var var;
  /*!
   * \brief The kind of type parameter.
   *
   * Given an explicit initial value (kType, i.e. 0) so that a
   * default-initialized node never carries an indeterminate enum value,
   * for instance when it is read through VisitAttrs before being assigned.
   */
  Kind kind{kType};

  /*!
   * \brief Hand the fields "var", "kind" and "span" to an attribute visitor.
   *
   * span is not declared in this class; presumably it is inherited from
   * RelayNode — confirm in base.h.
   *
   * \param v The visitor that receives each field by name.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("kind", &kind);
    v->Visit("span", &span);
  }

  /*!
   * \brief Factory for TypeVar; the definition lives outside this header.
   * \param name A name for the type variable (presumably the name given to
   *  var — confirm against the definition).
   * \param kind The kind of the type variable.
   * \return The created TypeVar.
   */
  TVM_DLL static TypeVar make(std::string name, Kind kind);

  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.TypeVar";
  TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
/*!
* \brief IncompleteType.
* These are intermediate values that are used during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeVar represents the input to the graph.
*/
class IncompleteType;

/*! \brief Container node holding the fields of an IncompleteType. */
class IncompleteTypeNode : public TypeNode {
 public:
  /*!
   * \brief The kind of the incomplete type, using the TypeVarNode kinds.
   *
   * Given an explicit initial value (TypeVarNode::kType, i.e. 0) so that a
   * default-initialized node never carries an indeterminate enum value,
   * for instance when it is read through VisitAttrs before being assigned.
   */
  TypeVarNode::Kind kind{TypeVarNode::kType};

  /*!
   * \brief Hand the fields "kind" and "span" to an attribute visitor.
   *
   * span is not declared in this class; presumably it is inherited from
   * RelayNode — confirm in base.h.
   *
   * \param v The visitor that receives each field by name.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("kind", &kind);
    v->Visit("span", &span);
  }

  /*!
   * \brief Factory for IncompleteType; the definition lives outside this header.
   * \param kind The kind of the incomplete type.
   * \return The created IncompleteType.
   */
  TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);

  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.IncompleteType";
  TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
/*!
 * \brief Potential constraints in the type.
 * \note This is reserved for future use.
 */
class TypeConstraint;

/*!
 * \brief Base container node of TypeConstraint.
 *
 * Declares no fields of its own; TypeRelationNode in this header derives
 * from it.
 */
class TypeConstraintNode : public TypeNode {
 public:
  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.TypeConstraint";
  TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type);
class FuncType;
/*!
* \brief Function type in Relay.
*
* Relay support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
/*! \brief type type of arguments */
tvm::Array<Type> arg_types;
/*! \brief The type of return value. */
Type ret_type;
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
tvm::Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
*/
tvm::Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TupleType";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
class TypeReporter;

/*!
 * \brief Reporter used to report type resolution information
 *  back to the solver.
 */
class TypeReporterNode : public Node {
 public:
  /*!
   * \brief Create a type equality constraint.
   *
   * The "assign direction" is a hint telling the solver that dst is more
   * likely to be resolved by src. The solver may nevertheless resolve
   * src by dst as well.
   *
   * \param dst The type on the destination side of the hint.
   * \param src The type on the source side of the hint.
   */
  TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;

  /*!
   * \brief Assert a shape expression comparison.
   * \note Use assert only if any of the condition inputs is symbolic.
   * \param cond The condition of the operation.
   * \return false if the assertion can be proven to have failed,
   *  true if the solver can still proceed.
   */
  TVM_DLL virtual bool Assert(const IndexExpr& cond) = 0;

  /*!
   * \brief Assert that two shape expressions equal each other.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \return false if the assertion can be proven to have failed,
   *  true if the solver can still proceed.
   */
  TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;

  /*!
   * \brief Set the location at which to report unification errors.
   * \param ref The program node to report the error.
   */
  TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;

  // The solver is not serializable, so no field is visited here.
  void VisitAttrs(tvm::AttrVisitor* visitor) final {}

  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.TypeReporter";
  TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
};
/*!
* \brief Container class of TypeReporter.
* \sa TypeReporterNode
*/
class TypeReporter : public NodeRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
}
TypeReporterNode* operator->() const {
return static_cast<TypeReporterNode*>(node_.get());
}
using ContainerType = TypeReporterNode;
};
/*!
 * \brief User defined type constraint function.
 *
 * If the input type information is sufficient to fully decide
 * the IncompleteTypes, the function should call Assign on the
 * reporter to report the new types, and return true.
 * Otherwise, the function should return false.
 *
 * \param args The arguments to the relation.
 *   The types are stored in the form of
 *   [input_type_0, input_type_1, ... input_type_n,
 *    output_type_0, output_type_1, ... output_type_m]
 *
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report the solution to.
 * \return false if this relation cannot be resolved,
 *   true if this relation has been resolved.
 */
using TypeRelationFn = TypedEnvFunc<bool(const Array<Type>& args,
                                         int num_inputs,
                                         const Attrs& attrs,
                                         const TypeReporter& reporter)>;
/*!
 * \brief User defined type relation, is an input-output relation on types.
 */
class TypeRelation;

/*!
 * \brief Container node holding the fields of a TypeRelation.
 * \note This node is not directly serializable.
 *  The type function needs to be looked up in the module.
 */
class TypeRelationNode : public TypeConstraintNode {
 public:
  /*!
   * \brief The function on input and output variables.
   *  It is not directly serializable and
   *  needs to be looked up in the module.
   */
  TypeRelationFn func;
  /*! \brief The type arguments to the type function. */
  tvm::Array<Type> args;
  /*!
   * \brief Number of input arguments.
   *
   * Given an explicit initial value (0) so that a default-initialized node
   * never carries an indeterminate integer, for instance when it is read
   * through VisitAttrs before being assigned.
   */
  int num_inputs{0};
  /*! \brief Attributes to the relation function. */
  Attrs attrs;

  /*!
   * \brief Hand the fields "func", "args", "num_inputs", "attrs" and "span"
   *  to an attribute visitor.
   *
   * span is not declared in this class; presumably it is inherited from
   * RelayNode — confirm in base.h.
   *
   * \param v The visitor that receives each field by name.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("func", &func);
    v->Visit("args", &args);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("attrs", &attrs);
    v->Visit("span", &span);
  }

  /*!
   * \brief Factory for TypeRelation; the definition lives outside this header.
   *
   * NOTE(review): the third parameter is named num_args here while the
   * corresponding field is num_inputs; presumably it populates num_inputs —
   * confirm against the definition.
   *
   * \param func The relation function.
   * \param args The type arguments to the type function.
   * \param num_args An argument count (see the note above).
   * \param attrs Attributes to the relation function.
   * \return The created TypeRelation.
   */
  TVM_DLL static TypeRelation make(TypeRelationFn func,
                                   Array<Type> args,
                                   int num_args,
                                   Attrs attrs);

  /*! \brief Type key of this node class. */
  static constexpr const char* _type_key = "relay.TypeRelation";
  TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
};

RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
// The following declarations name advanced typing constructs.
// Only the class names are kept here; they are reserved for future usage
// and no definition is provided in this header.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// NOTE(review): the original comment here also read "stores a DataType",
// which looks copy-pasted from the declaration above; presumably this
// class stores a shape — confirm.
class GenericShape;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TYPE_H_