blob: cdb8e52d2359f922846578c6d4a6449fd75e2b7f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/adt.h
* \brief Algebraic data types for Relay
*/
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_
#include <tvm/ir/adt.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <functional>
#include <string>
#include <utility>
namespace tvm {
namespace relay {
/*!
 * \brief Aliases that make the ADT constructor and type-data classes from the
 *  enclosing tvm namespace available under tvm::relay with the same names.
 */
using Constructor = tvm::Constructor;
using ConstructorNode = tvm::ConstructorNode;
using TypeData = tvm::TypeData;
using TypeDataNode = tvm::TypeDataNode;
/*!
 * \brief Base type for declaring relay pattern.
 *
 * Carries no fields of its own; it only fixes the type key and the two
 * structural equality / hashing flags below. Concrete pattern nodes in this
 * file derive from it.
 */
class PatternNode : public RelayNode {
public:
/*! \brief Type key string registered for this node type. */
static constexpr const char* _type_key = "relay.Pattern";
/*! \brief Flag stating that this type has an SEqualReduce method (as the name indicates). */
static constexpr const bool _type_has_method_sequal_reduce = true;
/*! \brief Flag stating that this type has an SHashReduce method (as the name indicates). */
static constexpr const bool _type_has_method_shash_reduce = true;
// NOTE(review): the parent named in the macro is Object, while the C++ base
// class is RelayNode — confirm that this difference is intentional.
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};
/*!
 * \brief Pattern is the base reference type for an ADT match pattern in Relay.
 *
 * A pattern is tried against an ADT value. It may accept the value and bind the
 * pattern variable to some value (typically the input itself or one of its
 * subnodes); otherwise it rejects the value.
 *
 * ADT pattern matching therefore tries a list of patterns and binds using the
 * first one that accepts the value.
 */
class Pattern : public ObjectRef {
 public:
  /*! \brief Node type this reference is associated with. */
  using ContainerType = PatternNode;

  /*! \brief Default constructor; default-initializes the ObjectRef base. */
  Pattern() {}

  /*!
   * \brief Construct from an object pointer.
   * \param p The object pointer forwarded to the ObjectRef base.
   */
  explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
};
/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;

/*!
 * \brief PatternWildcard container node.
 *
 * Declares no fields of its own; only the inherited span is visited.
 */
class PatternWildcardNode : public PatternNode {
 public:
  /*! \brief Type key string registered for this node type. */
  static constexpr const char* _type_key = "relay.PatternWildcard";

  /*!
   * \brief Visit the attributes of this node.
   * \param v The visitor; it is shown the "span" field only.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("span", &span);
  }

  /*!
   * \brief Structural equality: unconditionally true.
   * \param other The node compared against (not inspected).
   * \param equal The equality reducer (not used).
   * \return Always true.
   */
  bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
    return true;
  }

  /*!
   * \brief Structural hash: feeds nothing to the reducer.
   * \param hash_reduce The hash reducer (not used).
   */
  void SHashReduce(SHashReducer hash_reduce) const {}

  TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
};
/*!
 * \brief Reference type for PatternWildcardNode.
 *
 * Spells out its constructors, assignment operators and operator-> by hand.
 */
class PatternWildcard : public Pattern {
 public:
  /*! \brief Node type this reference is associated with. */
  using ContainerType = PatternWildcardNode;

  /*! \brief Default constructor (declared here, defined elsewhere). */
  TVM_DLL PatternWildcard();

  /*!
   * \brief Construct from an object pointer.
   * \param n The object pointer forwarded to the Pattern base.
   */
  explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {}

  /*! \brief Copy constructor: delegates using a copy of the source's data_. */
  PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}

  /*! \brief Move constructor: delegates using the source's moved-out data_. */
  PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}

  /*! \brief Copy assignment: copies data_ from \p other. */
  PatternWildcard& operator=(const PatternWildcard& other) {
    this->data_ = other.data_;
    return *this;
  }

  /*! \brief Move assignment: moves data_ out of \p other. */
  PatternWildcard& operator=(PatternWildcard&& other) {
    this->data_ = std::move(other.data_);
    return *this;
  }

  /*!
   * \brief Access the underlying node.
   * \return The result of get(), statically cast to a const PatternWildcardNode pointer.
   */
  const PatternWildcardNode* operator->() const {
    const auto* node = static_cast<const PatternWildcardNode*>(get());
    return node;
  }
};
/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;

/*! \brief PatternVar container node. */
class PatternVarNode : public PatternNode {
 public:
  /*! \brief Type key string registered for this node type. */
  static constexpr const char* _type_key = "relay.PatternVar";

  /*! \brief Variable that stores the matched value. */
  tvm::relay::Var var;

  /*!
   * \brief Visit the attributes of this node.
   * \param v The visitor; it is shown "var" and then "span".
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("var", &var);
    v->Visit("span", &span);
  }

  /*!
   * \brief Structural equality over the bound variable.
   * \param other The node compared against.
   * \param equal The equality reducer; its DefEqual is applied to the two vars.
   * \return The result of DefEqual on this node's var and \p other's var.
   */
  bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
    const bool vars_match = equal.DefEqual(var, other->var);
    return vars_match;
  }

  /*!
   * \brief Structural hash over the bound variable.
   * \param hash_reduce The hash reducer; its DefHash is applied to var.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce.DefHash(var);
  }

  TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
};
/*! \brief Reference type for PatternVarNode. */
class PatternVar : public Pattern {
public:
/*!
 * \brief Constructor
 * \param var The var to construct a pattern from (presumably stored as
 *  PatternVarNode::var; the definition is not in this header).
 */
TVM_DLL explicit PatternVar(tvm::relay::Var var);
// Object-reference boilerplate; the arguments name this reference type, its
// parent reference type and its container node type.
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;

/*! \brief PatternConstructor container node. */
class PatternConstructorNode : public PatternNode {
 public:
  /*! \brief Type key string registered for this node type. */
  static constexpr const char* _type_key = "relay.PatternConstructor";

  /*! \brief Constructor matched by the pattern. */
  Constructor constructor;
  /*! \brief Sub-patterns to match against each input to the constructor. */
  tvm::Array<Pattern> patterns;

  /*!
   * \brief Visit the attributes of this node.
   * \param v The visitor; it is shown "constructor", "patterns" and "span", in that order.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("constructor", &constructor);
    v->Visit("patterns", &patterns);
    v->Visit("span", &span);
  }

  /*!
   * \brief Structural equality over the constructor and the sub-patterns.
   * \param other The node compared against.
   * \param equal The equality reducer.
   * \return True only if both fields compare equal; the sub-patterns are not
   *  compared when the constructors already differ.
   */
  bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
    if (!equal(constructor, other->constructor)) {
      return false;
    }
    return equal(patterns, other->patterns);
  }

  /*!
   * \brief Structural hash over the constructor, then the sub-patterns.
   * \param hash_reduce The hash reducer.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(constructor);
    hash_reduce(patterns);
  }

  TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};
/*! \brief Reference type for PatternConstructorNode. */
class PatternConstructor : public Pattern {
public:
/*!
 * \brief Constructor
 * \param constructor The constructor of a pattern
 * \param patterns The sub-patterns for matching
 */
TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);
// Object-reference boilerplate; the arguments name this reference type, its
// parent reference type and its container node type.
TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
};
/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;

/*! \brief PatternTuple container node. */
class PatternTupleNode : public PatternNode {
 public:
  /*! \brief Type key string registered for this node type. */
  static constexpr const char* _type_key = "relay.PatternTuple";

  /* TODO(@jroesch): rename to field_pats */
  /*! \brief Sub-patterns to match against each value of the tuple. */
  tvm::Array<Pattern> patterns;

  /*!
   * \brief Visit the attributes of this node.
   * \param v The visitor; it is shown "patterns" and then "span".
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("patterns", &patterns);
    v->Visit("span", &span);
  }

  /*!
   * \brief Structural equality over the sub-patterns.
   * \param other The node compared against.
   * \param equal The equality reducer.
   * \return The reducer's verdict on the two pattern arrays.
   */
  bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
    const bool same_patterns = equal(patterns, other->patterns);
    return same_patterns;
  }

  /*!
   * \brief Structural hash over the sub-patterns.
   * \param hash_reduce The hash reducer.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(patterns);
  }

  TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};
/*! \brief Reference type for PatternTupleNode. */
class PatternTuple : public Pattern {
public:
/*!
 * \brief Constructor
 * \param patterns The sub-patterns to match against each value of the tuple
 */
TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);
// Object-reference boilerplate; the arguments name this reference type, its
// parent reference type and its container node type.
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};
/*! \brief A clause in a match expression. */
class Clause;

/*!
 * \brief Clause container node.
 *
 * Pairs a pattern (lhs) with the expression produced when it matches (rhs).
 * Derives directly from Object and visits no span.
 */
class ClauseNode : public Object {
 public:
  /*! \brief Type key string registered for this node type. */
  static constexpr const char* _type_key = "relay.Clause";
  /*! \brief Flag stating that this type has an SEqualReduce method. */
  static constexpr const bool _type_has_method_sequal_reduce = true;
  /*! \brief Flag stating that this type has an SHashReduce method. */
  static constexpr const bool _type_has_method_shash_reduce = true;

  /*! \brief The pattern the clause matches. */
  Pattern lhs;
  /*! \brief The resulting value. */
  Expr rhs;

  /*!
   * \brief Visit the attributes of this node.
   * \param v The visitor; it is shown "lhs" and then "rhs".
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("lhs", &lhs);
    v->Visit("rhs", &rhs);
  }

  /*!
   * \brief Structural equality over the pattern and the result expression.
   * \param other The node compared against.
   * \param equal The equality reducer.
   * \return True only if both fields compare equal; rhs is not compared when
   *  lhs already differs.
   */
  bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
    if (!equal(lhs, other->lhs)) {
      return false;
    }
    return equal(rhs, other->rhs);
  }

  /*!
   * \brief Structural hash over lhs, then rhs.
   * \param hash_reduce The hash reducer.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(lhs);
    hash_reduce(rhs);
  }

  TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
/*! \brief Reference type for ClauseNode. */
class Clause : public ObjectRef {
public:
/*!
 * \brief Constructor
 * \param lhs The pattern matched by the clause.
 * \param rhs The resulting value
 */
TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
// Object-reference boilerplate; the arguments name this reference type, its
// parent reference type and its container node type.
TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
// Additional macro-generated method for ClauseNode (copy-on-write support,
// going by the macro name — see the macro definition for details).
TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode);
};
/*!
 * \brief Returns \p clause with the given properties. A null property denotes 'no change'.
 * Returns \p clause if all properties are unchanged. Otherwise, returns a copy with the new
 * fields.
 * \param clause The clause to update.
 * \param opt_lhs Replacement pattern, or null (the default) to keep the current one.
 * \param opt_rhs Replacement result expression, or null (the default) to keep the current one.
 * \return \p clause itself when nothing changes, otherwise a copy carrying the new fields.
 */
Clause WithFields(Clause clause, Optional<Pattern> opt_lhs = Optional<Pattern>(),
Optional<Expr> opt_rhs = Optional<Expr>());
/*! \brief ADT pattern matching expression. */
class Match;

/*! \brief Match container node. */
class MatchNode : public ExprNode {
 public:
  /*! \brief Type key string registered for this node type. */
  static constexpr const char* _type_key = "relay.Match";

  /*! \brief The input being deconstructed. */
  Expr data;
  /*! \brief The match node clauses. */
  tvm::Array<Clause> clauses;
  /*!
   * \brief Should this match be complete (cover all cases)?
   *  If yes, the type checker will generate an error if there are any missing cases.
   */
  bool complete;

  /*!
   * \brief Visit the attributes of this node.
   * \param v The visitor; it is shown "data", "clauses", "complete",
   *  "virtual_device_", "span" and "_checked_type_", in that order.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("clauses", &clauses);
    v->Visit("complete", &complete);
    v->Visit("virtual_device_", &virtual_device_);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Structural equality over data, clauses and complete.
   *
   * Calls MarkGraphNode() on the reducer first, then compares the fields in
   * declaration order and stops at the first mismatch.
   *
   * \param other The node compared against.
   * \param equal The equality reducer.
   * \return True only if all three fields compare equal.
   */
  bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
    equal->MarkGraphNode();
    if (!equal(data, other->data)) {
      return false;
    }
    if (!equal(clauses, other->clauses)) {
      return false;
    }
    return equal(complete, other->complete);
  }

  /*!
   * \brief Structural hash over data, clauses and complete, in that order.
   *
   * Calls MarkGraphNode() on the reducer before hashing the fields.
   *
   * \param hash_reduce The hash reducer.
   */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce->MarkGraphNode();
    hash_reduce(data);
    hash_reduce(clauses);
    hash_reduce(complete);
  }

  TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
/*! \brief Reference type for MatchNode. */
class Match : public Expr {
public:
/*!
 * \brief Constructor
 * \param data the input being deconstructed.
 * \param clauses The clauses for matching.
 * \param complete Indicate if this match is complete (defaults to true).
 * \param span The span of the expression (defaults to a default-constructed Span).
 */
TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
// Object-reference boilerplate; the arguments name this reference type, the
// parent reference type and the container node type.
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
// Additional macro-generated method for MatchNode (copy-on-write support,
// going by the macro name — see the macro definition for details).
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode);
};
/*!
 * \brief Returns \p match with the given properties. A null property denotes 'no change'.
 * Returns \p match if all properties are unchanged. Otherwise, returns a copy with the new
 * fields.
 * \param match The match expression to update.
 * \param opt_data Replacement input expression, or null (the default) to keep the current one.
 * \param opt_clauses Replacement clauses, or null (the default) to keep the current ones.
 * \param opt_complete Replacement completeness flag, or null (the default) to keep the current one.
 * \param opt_span Replacement span, or null (the default) to keep the current one.
 * \return \p match itself when nothing changes, otherwise a copy carrying the new fields.
 */
Match WithFields(Match match, Optional<Expr> opt_data = Optional<Expr>(),
Optional<Array<Clause>> opt_clauses = Optional<Array<Clause>>(),
Optional<Bool> opt_complete = Optional<Bool>(),
Optional<Span> opt_span = Optional<Span>());
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ADT_H_