blob: 65b454f08b52f4abb52438531ac673434e2f5018 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/type.h
* \brief IR/AST nodes for the unified type system in TVM.
*
* We use Relay's type system as the unified type system
* throughout the stack.
*
* This file contains types that are common across IR variants.
*
* ## Relation between Type and runtime::DataType
*
* Besides Type, we also store a dtype field in the low-level PrimExpr.
* runtime::DataType(dtype) provides coarse-grained type information
* during compile time and runtime. It is eagerly built in
* low-level expression construction and can be used for
* quick type checking in the low-level IR.
* For example, when an Expr's dtype is int32,
* we know for sure that its type is also int32.
*
* On the other hand, Type provides more fine-grained information.
* For example, a low level expression can have DataType::Handle() as
* its dtype and MemRef[float32] as its type.
* Types are usually lazily constructed via type checking,
* so they may not readily be available during IR construction.
*
* The unified Type serves as a common bridge across IR dialects.
* For example, we require all functions to have a type signature,
* which allows us to build cross-dialect function calls.
*/
#ifndef TVM_IR_TYPE_H_
#define TVM_IR_TYPE_H_
#include <tvm/ir/span.h>
#include <tvm/node/container.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <string>
namespace tvm {
/*!
 * \brief Base node of every type in the unified type system.
 *
 * The following subclasses belong to Relay's type system:
 *
 * - PrimType: type of primitive values used in the low-level IR.
 * - FuncType: type of a function.
 * - TensorType: type of certain Tensor values in the expression.
 *
 * More advanced types additionally support generic (polymorphic) types.
 * \sa Type
 */
class TypeNode : public Object {
 public:
  /*!
   * \brief Span pointing back to the original source code.
   *
   * Reserved for debug information. Declared mutable so that it can be
   * assigned even through a const node.
   */
  mutable Span span;

  static constexpr const char* _type_key = "Type";
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;
  static constexpr uint32_t _type_child_slots = 14;
  TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
/*!
 * \brief Managed reference to TypeNode.
 *
 * Common base of the reference classes declared in this header: PrimType,
 * PointerType, TypeVar, GlobalTypeVar, TupleType, TypeConstraint, FuncType,
 * IncompleteType and RelayRefType all derive from it.
 * \sa TypeNode
 */
class Type : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode);
};
/*!
* \brief Primitive data types used in the low-level IR.
*
* PrimType represents POD-values and handles that are
* not automatically managed by the runtime.
*
* \sa PrimType
*/
class PrimTypeNode : public TypeNode {
public:
/*!
* \brief The corresponding dtype field.
*/
runtime::DataType dtype;
void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
/*!
 * \brief Managed reference to PrimTypeNode.
 * \sa PrimTypeNode
 */
class PrimType : public Type {
 public:
  /*!
   * \brief Construct a primitive type.
   * \param dtype The corresponding dtype.
   */
  TVM_DLL explicit PrimType(runtime::DataType dtype);
  TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
/*!
* \brief Low-level raw pointer type.
*
* PointerType represents type hints in the TIR to be
* passed to the final code generator.
*
* PointerType should not occur in the high-level analysis.
*
* \sa PointerType
*/
class PointerTypeNode : public TypeNode {
public:
/*!
* \brief The type of the element which the pointer points to.
*/
Type element_type;
void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); }
bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); }
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
/*!
 * \brief Managed reference to PointerTypeNode.
 * \sa PointerTypeNode
 */
class PointerType : public Type {
 public:
  /*!
   * \brief Construct a pointer type.
   * \param element_type The type of the element which the pointer points to.
   */
  TVM_DLL explicit PointerType(Type element_type);
  TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
/*!
 * \brief Possible kinds of TypeVars.
 *
 * The kind is stored by TypeVarNode, GlobalTypeVarNode and IncompleteTypeNode
 * and is exposed through their VisitAttrs("kind"), so these numeric values may
 * end up persisted by reflection-based tooling; avoid renumbering them.
 * Note that the numbering is not contiguous: no enumerator uses the value 3.
 */
enum TypeKind : int {
/*! \brief Presumably an ordinary type variable (NOTE(review): confirm). */
kType = 0,
/*! \brief Template variable in shape expression. */
kShapeVar = 1,
/*! \brief NOTE(review): meaning not documented in this header — confirm. */
kBaseType = 2,
/*! \brief Presumably the kind used for TypeConstraint values (confirm). */
kConstraint = 4,
/*! \brief Presumably a handle to an algebraic data type (confirm). */
kAdtHandle = 5,
/*! \brief Presumably a type-data (ADT definition) entry (confirm). */
kTypeData = 6
};
/*!
* \brief Type parameter in functions.
*
* A type variable can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeVar of f is TypeVar("n", kind=kShapeVar).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
* \code
*
* template<i32 n>
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeVar, TypeKind
*/
class TypeVarNode : public TypeNode {
public:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
/*!
 * \brief Managed reference to TypeVarNode.
 * \sa TypeVarNode
 */
class TypeVar : public Type {
public:
/*!
 * \brief Construct a type variable.
 * \param name_hint The name of the type var. It is only a hint:
 *        TypeVarNode::SEqualReduce does not compare it.
 * \param kind The kind of the type var.
 */
TVM_DLL TypeVar(String name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode);
};
/*!
 * \brief A global type variable that is used for defining new types or type aliases.
 * \sa GlobalTypeVar
 */
class GlobalTypeVarNode : public TypeNode {
 public:
  /*!
   * \brief The name of the variable.
   *
   * Unlike TypeVarNode::name_hint, this name DOES participate in structural
   * equality and hashing (see SEqualReduce / SHashReduce below).
   */
  String name_hint;
  /*!
   * \brief The kind of type parameter.
   * \note SEqualReduce / SHashReduce below do not consider the kind.
   */
  TypeKind kind;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name_hint", &name_hint);
    v->Visit("kind", &kind);
    // NOTE(review): unlike TypeVarNode, `span` is not visited here. Adding it
    // would change the reflected (and possibly serialized) form of this node,
    // so confirm compatibility before doing so.
  }

  bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
    // name matters for now in global type var.
    return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    // Kept in sync with SEqualReduce: name first, then the free-var contribution.
    hash_reduce(name_hint);
    hash_reduce.FreeVarHashImpl(this);
  }

  static constexpr const char* _type_key = "GlobalTypeVar";
  TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
/*!
 * \brief Managed reference to GlobalTypeVarNode.
 * \sa GlobalTypeVarNode
 */
class GlobalTypeVar : public Type {
public:
/*!
 * \brief Construct a global type variable.
 * \param name_hint The name of the type var. Unlike for TypeVar, this name is
 *        compared by GlobalTypeVarNode::SEqualReduce.
 * \param kind The kind of the type var.
 */
TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
/*!
* \brief The type of tuple values.
* \sa TupleType
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
return equal(fields, other->fields);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
/*!
 * \brief Managed reference to TupleTypeNode.
 * \sa TupleTypeNode
 */
class TupleType : public Type {
 public:
  /*!
   * \brief Constructor.
   * \param fields Fields in the tuple.
   */
  TVM_DLL explicit TupleType(Array<Type> fields);
  /*!
   * \brief Create an empty tuple type that contains nothing.
   * \return An empty tuple type.
   * \note VoidType() is represented by this empty tuple type.
   */
  TVM_DLL TupleType static Empty();
  TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};
/*!
 * \brief Obtain the type used to represent void.
 * \return An empty TupleType, returned as a Type.
 */
inline Type VoidType() {
  Type void_type = TupleType::Empty();
  return void_type;
}
/*!
* \brief Check whether the tyep represents void.
* \return The check result.
*/
inline bool IsVoidType(const Type& type) {
auto* n = type.as<TupleTypeNode>();
return n && n->fields.size() == 0;
}
/*!
 * \brief Base node for the constraints a function type may carry.
 * \sa TypeConstraint
 */
class TypeConstraintNode : public TypeNode {
 public:
  static constexpr const char* _type_key = "TypeConstraint";
  static constexpr uint32_t _type_child_slots = 1;
  TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
/*!
 * \brief Managed reference to TypeConstraintNode.
 *
 * Used by FuncTypeNode::type_constraints. No constructor is declared here;
 * concrete constraints are presumably created through subclasses such as
 * TypeRelation, declared elsewhere (NOTE(review): confirm).
 * \sa TypeConstraintNode, TypeRelation
 */
class TypeConstraint : public Type {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode);
};
/*!
* \brief Function type.
*
* We support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa FuncType, TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
/*! \brief type type of arguments */
Array<Type> arg_types;
/*! \brief The type of return value. */
Type ret_type;
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
*/
Array<TypeConstraint> type_constraints;
void VisitAttrs(AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
// type params first as they defines type vars.
return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) &&
equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(type_params);
hash_reduce(arg_types);
hash_reduce(ret_type);
hash_reduce(type_constraints);
}
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
/*!
 * \brief Managed reference to FuncTypeNode.
 * \sa FuncTypeNode
 */
class FuncType : public Type {
public:
/*!
 * \brief Construct a function type.
 * \param arg_types The types of the arguments.
 * \param ret_type The type of the return value.
 * \param type_params The type parameters; empty for non-polymorphic functions.
 * \param type_constraints The type constraints; empty for non-polymorphic functions.
 * \sa FuncTypeNode for more docs about these fields.
 */
TVM_DLL FuncType(Array<Type> arg_types, Type ret_type, Array<TypeVar> type_params,
Array<TypeConstraint> type_constraints);
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
/*!
* \brief Intermediate values that is used to indicate incomplete type
* during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeVar represents the input to the graph.
*
* \sa IncompleteType
*/
class IncompleteTypeNode : public TypeNode {
public:
/*! \brief kind of the type. */
TypeKind kind;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("span", &span);
}
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); }
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
/*!
 * \brief Managed reference to IncompleteTypeNode.
 * \sa IncompleteTypeNode
 */
class IncompleteType : public Type {
public:
/*!
 * \brief Construct an incomplete type.
 * \param kind Kind of the type (stored in IncompleteTypeNode::kind).
 */
TVM_DLL explicit IncompleteType(TypeKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
};
/*!
* \brief Reference Type High-level Relay IR.
*
* \sa RelayRefType.
*/
class RelayRefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;
RelayRefTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}
bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};
/*!
 * \brief Managed reference to RelayRefTypeNode.
 * \sa RelayRefTypeNode.
 */
class RelayRefType : public Type {
public:
/*!
 * \brief Construct a Relay reference type.
 * \param value The type of the value held by the reference.
 */
TVM_DLL explicit RelayRefType(Type value);
TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
};
} // namespace tvm
#endif // TVM_IR_TYPE_H_