blob: 466a4f00fd5fe55bf36a3217ca09af5d9950fd20 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/adt.h
* \brief Algebraic data type definitions.
*
* We adopt relay's ADT definition as a unified class
* for describing structured data.
*/
#ifndef TVM_IR_ADT_H_
#define TVM_IR_ADT_H_
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <tvm/node/container.h>
#include <tvm/node/node.h>
#include <tvm/runtime/object.h>
#include <string>
namespace tvm {
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
* \sa Constructor
*/
class ConstructorNode : public RelayExprNode {
public:
/*! \brief The name (only a hint) */
String name_hint;
/*! \brief Input to the constructor. */
Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int32_t tag = -1;
ConstructorNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
// Use namehint for now to be consistent with the legacy relay impl
// TODO(tvm-team) revisit, need to check the type var.
return equal(name_hint, other->name_hint) && equal(inputs, other->inputs);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce(inputs);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
/*!
* \brief Managed reference to ConstructorNode.
*
* Derives from RelayExpr, so a Constructor can be used wherever a
* RelayExpr reference is accepted.
* \sa ConstructorNode
*/
class Constructor : public RelayExpr {
public:
/*!
* \brief Create an ADT constructor from its name hint, input types and owning datatype.
*
* Only the declaration lives in this header; the definition is provided elsewhere.
* \param name_hint The name of the constructor (stored in ConstructorNode as a hint only).
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
TVM_DLL Constructor(String name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
// NOTE(review): macro defined outside this file; presumably supplies the usual
// reference-class boilerplate (accessors yielding ConstructorNode) -- confirm
// against its definition.
TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
Array<TypeVar> type_vars;
/*! \brief The constructors. */
Array<Constructor> constructors;
void VisitAttrs(AttrVisitor* v) {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
return equal.DefEqual(header, other->header) && equal.DefEqual(type_vars, other->type_vars) &&
equal(constructors, other->constructors);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(header);
hash_reduce.DefHash(type_vars);
hash_reduce(constructors);
}
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it, and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*
* This class is the managed reference to TypeDataNode.
* \sa TypeDataNode
*/
class TypeData : public Type {
public:
/*!
* \brief Create the type data of an ADT from its header, type variables and constructors.
*
* Only the declaration lives in this header; the definition is provided elsewhere.
* \param header The global type var that names the ADT.
* \param type_vars The type variables of the ADT.
* \param constructors The constructors of the ADT.
*/
TVM_DLL TypeData(GlobalTypeVar header, Array<TypeVar> type_vars, Array<Constructor> constructors);
// NOTE(review): macro defined outside this file; presumably supplies the usual
// reference-class boilerplate (accessors yielding TypeDataNode) -- confirm
// against its definition.
TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode);
};
} // namespace tvm
#endif // TVM_IR_ADT_H_