blob: daad8514f9ff5ca39e4a45f969bb077392eda8ff [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/expr.h
* \brief Relay expression language.
*/
#ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/ir/op.h>
#include <functional>
#include <stack>
#include <string>
#include <utility>
#include "./base.h"
#include "./type.h"
namespace tvm {
namespace relay {
using Expr = tvm::RelayExpr;
using ExprNode = tvm::RelayExprNode;
using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
using tvm::PrettyPrint;
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
*
* \note Scalar constants are represented by rank-0 const tensor.
* Constant folding are handled uniformly via Tensor types.
*/
class Constant;
/*!
* \brief Constant tensor type.
*/
class ConstantNode : public ExprNode {
public:
/*! \brief The data of the tensor */
runtime::NDArray data;
/*! \return The corresponding tensor type of the data */
TensorType tensor_type() const;
/*! \return Whether it is scalar(rank-0 tensor) */
bool is_scalar() const { return data->ndim == 0; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
return equal(data, other->data);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
class Constant : public Expr {
public:
/*!
* \brief The constructor
* \param data The data of the constant tensor.
* \param span The source span of the expression.
*/
TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Tuple container */
class TupleNode : public ExprNode {
public:
/*! \brief the fields of the tuple */
tvm::Array<relay::Expr> fields;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
// specially handle empty tuple as a constant is not a graph node.
if (fields.size() == other->fields.size() && fields.size() == 0) {
return true;
} else {
equal->MarkGraphNode();
return equal(fields, other->fields);
}
}
void SHashReduce(SHashReducer hash_reduce) const {
if (fields.size() != 0) {
hash_reduce->MarkGraphNode();
hash_reduce(fields);
}
}
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
class Tuple : public Expr {
public:
/*!
* \brief The constructor
* \param fields The fields of a tuple.
* \param span The source span of the expression.
*/
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
};
/*!
* \brief Local variables used in the let expression.
*
* Its semantics are similar to tvm.Var node used in TVM's low level
* tensor expression language.
*
* \note Each Var is bind only once and is immutable.
*/
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
/*!
* \brief The unique identifier of the Var.
*
* vid will be preserved for the same Var during type inference
* and other rewritings, while the VarNode might be recreated
* to attach additional information.
* This property can be used to keep track of parameter Var
* information across passes.
*/
Id vid;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;
/*! \return The name hint of the variable */
const String& name_hint() const { return vid->name_hint; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(type_annotation, other->type_annotation) && equal(vid, other->vid);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(type_annotation);
hash_reduce(vid);
}
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
class Var : public Expr {
public:
/*!
* \brief The constructor
* \param name_hint The name hint of a variable.
* \param type_annotation The type annotation of a variable.
* \param span The source span of the expression.
*/
TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span())
: Var(Id(name_hint), type_annotation, span) {}
/*!
* \brief The constructor
* \param vid The unique id of a variable.
* \param type_annotation The type annotation of a variable.
* \param span The source span of the expression.
*/
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
*/
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
protected:
// CallNode uses own deleter to indirectly call non-recursive destructor
Object::FDeleter saved_deleter_;
static void Deleter_(Object* ptr);
public:
/*!
* \brief The operator(function) being invoked
*
* - It can be tvm::Op which corresponds to the primitive operators.
* - It can also be user defined functions (Function, GlobalVar, Var).
*/
Expr op;
/*! \brief The arguments(inputs) of the call */
tvm::Array<relay::Expr> args;
/*! \brief The additional attributes */
Attrs attrs;
/*!
* \brief The type arguments passed to polymorphic(template) function.
*
* This is the advance feature that is only used when the function is
* polymorphic. It is safe to be ignored in most cases. For example, in the
* following code, the type_args of addone call is [int].
*
* \code
*
* template<typename T>
* T addone(T a) { return a + 1; }
*
* void main() {
* int x = addone<int>(10);
* }
*
* \endcode
*/
tvm::Array<Type> type_args;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
// skip type_args check for primitive ops.
equal->MarkGraphNode();
return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
(IsPrimitiveOp(op) || equal(type_args, other->type_args));
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(op);
hash_reduce(args);
hash_reduce(attrs);
if (!IsPrimitiveOp(op)) {
hash_reduce(type_args);
}
}
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
friend class Call;
};
class Call : public Expr {
public:
/*!
* \brief The destructor
*/
~Call();
/*!
* \brief The constructor
* \param op The operator will be invoked.
* \param args The arguments of the call.
* \param attrs The attributes of the call node.
* \param type_args The type arguments passed to a polymorphic function.
* \param span The source span of the expression.
*/
TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>(), Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
};
/*!
* \brief Let binding that binds a local var and optionally a type annotation.
*
* \note Let is useful to transform the program to be A-normal form.
* where each of the expression corresponds to a let binding.
*
* For developers who are familar with the computational graph.
* Each of the let can be viewed as a operator node in the computational graph.
* Traversing the list of let bindings is similar to running
* PostDFS-order(topo-order) traversal on the computational graph.
*/
class Let;
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
public:
/*! \brief The variable we bind to */
Var var;
/*! \brief The value we bind var to */
Expr value;
/*! \brief The body of the let binding */
Expr body;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal.DefEqual(var, other->var) && equal(value, other->value) &&
equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce.DefHash(var);
hash_reduce(value);
hash_reduce(body);
}
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
class Let : public Expr {
public:
/*!
* \brief The constructor
* \param var The variable that is bound to.
* \param value The value used to bind to the variable.
* \param body The body of the let binding.
* \param span The source span of the expression.
*/
TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
};
/*!
* \brief Condition expression
*
* Unlike traditional statement `if`s, the if evalutes
* to the result of the branch taken.
*
* let x = if (true) { 1 } else { 0 }; // x is 1
* let y = if (false) { 1 } else { 0 }; // y is 0
*
* \note This is similar to C's ternary operator.
*/
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
public:
/*! \brief The condition */
Expr cond;
/*! \brief The expression evaluated when condition is true. */
Expr true_branch;
/*! \brief The expression evaluated when condition is false */
Expr false_branch;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
equal(false_branch, other->false_branch);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(cond);
hash_reduce(true_branch);
hash_reduce(false_branch);
}
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
class If : public Expr {
public:
/*!
* \brief The constructor
* \param cond The condition of a if node.
* \param true_branch The fall through branch
* \param false_branch The branch for execution when condition is false.
* \param span The source span of the expression.
*/
TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
};
/*! \brief Get index-th field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple Expression */
Expr tuple;
/*! \brief which value to get */
int index;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
return equal(tuple, other->tuple) && equal(index, other->index);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(tuple);
hash_reduce(index);
}
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
class TupleGetItem : public Expr {
public:
/*!
* \brief The constructor
* \param tuple The tuple to get an element from.
* \param index The index for extracting a value in the tuple.
* \param span The source span of the expression.
*/
TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
};
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(value);
}
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
class RefCreate : public Expr {
public:
/*!
* \brief The constructor
* \param value The initial value of the reference.
* \param span The source span of the expression.
*/
TVM_DLL explicit RefCreate(Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
};
/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(ref, other->ref);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(ref);
}
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
class RefRead : public Expr {
public:
/*!
* \brief The constructor
* \param ref The reference where to read data.
* \param span The source span of the expression.
*/
TVM_DLL explicit RefRead(Expr ref, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
};
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
/*! \brief The value to write into. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(ref, other->ref) && equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(ref);
hash_reduce(value);
}
static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode);
};
class RefWrite : public Expr {
public:
/*!
* \brief The constructor
* \param ref The reference where data is write to.
* \param value The value to write.
* \param span The source span of the expression.
*/
TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
};
/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind of TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class TempExprNode : public ExprNode {
public:
/*! \brief virtual destructor */
virtual ~TempExprNode() {}
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
static constexpr const uint32_t _type_child_slots = 0;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
class TempExpr : public Expr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_