/*!
* Copyright (c) 2016 by Contributors
* \file tvm/expr.h
* \brief The Expr and related elements in DataFlow construction.
*/
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
#include <type_traits>
#include <unordered_map>
#include "base.h"
#include "runtime/c_runtime_api.h"
namespace tvm {
using HalideIR::Type;
using HalideIR::Float;
using HalideIR::Bool;
using HalideIR::Int;
using HalideIR::UInt;
using HalideIR::Handle;
using HalideIR::ExprHash;
using HalideIR::ExprEqual;
using HalideIR::Expr;
using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
}
}
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
}
inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}
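/*
 * Example (illustrative sketch): converting between the HalideIR Type and the
 * runtime TVMType representations and back.
 *
 *   TVMType dl = Type2TVMType(Float(32));   // dl.bits == 32, dl.lanes == 1
 *   Type back = TVMType2Type(dl);           // back is Float(32) again
 */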
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
int data_bits = dtype.bits() * dtype.lanes();
  // bool is allowed as a special case and occupies one byte
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}
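/*
 * Example (illustrative sketch):
 *
 *   GetVectorBytes(Float(32));     // 4
 *   GetVectorBytes(Float(32, 4));  // 16, a 4-lane float32 vector
 *   GetVectorBytes(Bool());        // 1, bool is special-cased
 */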
/*! \brief a named variable in TVM */
class Var : public HalideIR::VarExpr {
public:
EXPORT explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(NodePtr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}
/*!
   * \brief Make a new copy of the Var with the same type,
   *  with the given suffix appended to its name hint.
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
}
  /*! \brief type indicating the container type */
using ContainerType = Variable;
};
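/*
 * Example (illustrative sketch):
 *
 *   Var i("i");                                   // Int(32) by default
 *   Var x("x", Float(32));
 *   Var i_outer = i.copy_with_suffix(".outer");   // name hint "i.outer", same type
 */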
/*!
* \brief Container of constant integer (IntImm).
*
 * This is used to store, and automate the type checking of,
 * attributes that must be constant integers.
*/
class Integer : public Expr {
public:
Integer() : Expr() {}
/*!
* \brief constructor from node.
*/
explicit Integer(NodePtr<Node> node) : Expr(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : Expr(value) {} // NOLINT(*)
/*!
   * \brief Assign another Integer.
   * \param other the other Integer.
*/
Integer& operator=(const Integer& other) {
node_ = other.node_;
return *this;
}
/*!
   * \brief Get a pointer to the internal IntImm node.
   * \return the pointer to the internal IntImm node.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(node_.get());
}
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
    CHECK(node_ != nullptr)
        << "Trying to reference a null Integer";
return (*this)->value;
}
  /*! \brief type indicating the container type */
using ContainerType = IntImm;
};
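/*
 * Example (illustrative sketch):
 *
 *   Integer pad = 1;           // implicit construction from int
 *   int64_t v = pad;           // conversion via operator int64_t(), v == 1
 *   int64_t w = pad->value;    // direct access to the underlying IntImm
 */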
/*! \brief container class of iteration variable. */
class IterVarNode;
/*!
 * \brief Same as HalideIR::IR::Range,
 *  except that it provides a constructor taking (begin, end).
 *
 * \note Halide's traditional Range has a constructor taking
 *  (begin, extent), which does not match the convention in e.g. Python.
 *  We corrected this by removing that constructor from HalideIR
 *  and adding it back in TVM's Range.
*/
class Range : public HalideIR::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
TVM_DLL Range(Expr begin, Expr end);
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
};
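/*
 * Example (illustrative sketch):
 *
 *   Var n("n");
 *   Range r(0, n);                                // covers [0, n), extent is n
 *   Range r2 = Range::make_by_min_extent(2, n);   // covers [2, 2 + n)
 */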
using Region = Array<Range>;
/*!
* \brief Type of iteration variable.
 *  Each IterVar has a specific type.
 *
 *  The type of an iter var can be overridden via
 *  stage.iter_var_attrs, provided the types are compatible.
*/
enum IterVarType : int {
/*!
* \brief Data parallel iteration.
   *  This normally corresponds to an axis of a Tensor.
   *  All IterVar manipulations are allowed.
   *
   * \note This does not mean the loop
   *  has to be executed in parallel.
*/
kDataPar = 0,
/*!
* \brief The IterVar itself is a thread-index
* of a fixed thread launching group.
   *  Note that this is already assumed to be parallelized.
*
* Disallow: split/fuse/vectorize/parallel
*/
kThreadIndex = 1,
/*!
   * \brief Commutative reduction.
* Cannot be directly parallelized.
*
* Disallow: parallel/vectorize
*/
kCommReduce = 2,
/*!
   * \brief Serial loops with loop-carried dependency;
   *  the iterations must execute in order.
   *  Cannot be reordered.
*
* Disallow: reorder/parallel/vectorize
*/
kOrdered = 3,
/*!
   * \brief The IterVar is opaque.
   *
   *  It may not correspond to any generated loop.
   *  Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement composite ops
   *  or external ops, where the corresponding loops are not
   *  generated by the schedule.
*/
kOpaque = 4,
// The following are possible additional
// types that are provided during schedule
/*!
* \brief The execution is unrolled.
*/
kUnrolled = 5,
/*!
* \brief The loop is vectorized.
*/
kVectorized = 6,
/*!
* \brief The loop is parallelized.
*/
kParallelized = 7,
/*!
* \brief Marks boundary of tensorization intrinsic.
*/
kTensorized = 8
};
/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*/
class IterVar : public NodeRef {
public:
// construct a new iter var without a domain
IterVar() {}
// construct from shared ptr.
explicit IterVar(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarNode* operator->() const;
/*!
* \return the corresponding var in the IterVar.
*/
inline operator Expr() const;
/*! \brief specify container node */
using ContainerType = IterVarNode;
};
/*!
* \brief Create a new IterVar that represents an axis in thread.
*
* \param dom Optional, domain of the thread axis.
* \param tag The thread tag of the axis.
*/
TVM_DLL IterVar thread_axis(Range dom, std::string tag);
/*!
* \brief Create a new IterVar for reduction operations.
*
* \param dom The domain of the reduction axis.
* \param name The name of the reduction axis.
*/
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
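/*
 * Example (illustrative sketch, assuming a loop bound Var n declared elsewhere):
 *
 *   IterVar tx = thread_axis(Range(0, 64), "threadIdx.x");
 *   IterVar k  = reduce_axis(Range(0, n), "k");
 *   Expr idx = k;   // an IterVar converts to its underlying Var via operator Expr()
 */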
using Domain = Array<Range>;
// print functions for expr
TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
*/
TVM_DLL void Dump(const NodeRef& node);
// definition of Node.
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
*/
class IterVarNode : public Node {
public:
/*!
   * \brief The domain of iteration, if known; can be None.
   *  It may be unset for intermediate schedule nodes, before scheduling.
*/
Range dom;
/*! \brief The looping variable */
Var var;
/*! \brief The type of the IterVar */
IterVarType iter_type;
/*!
* \brief additional tag on the iteration variable,
   *  set this if the variable is already bound to a known thread tag.
*/
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
v->Visit("thread_tag", &thread_tag);
}
TVM_DLL static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
};
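/*
 * Example (illustrative sketch):
 *
 *   IterVar iv = IterVarNode::make(Range(0, 16), Var("i"), kDataPar);
 *   // iv->dom covers [0, 16); iv->var is the loop variable "i"
 */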
// inline implementations
inline const IterVarNode* IterVar::operator->() const {
return static_cast<const IterVarNode*>(node_.get());
}
inline IterVar::operator Expr() const {
return (*this)->var;
}
inline const char* IterVarType2String(IterVarType t) {
switch (t) {
case kDataPar: return "DataPar";
case kThreadIndex: return "ThreadIndex";
case kCommReduce: return "CommReduce";
case kOrdered: return "Ordered";
case kOpaque: return "Opaque";
case kUnrolled: return "Unrolled";
case kVectorized: return "Vectorized";
case kParallelized: return "Parallelized";
case kTensorized: return "Tensorized";
}
return "Unknown";
}
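/*
 * Example (illustrative sketch):
 *
 *   const char* s = IterVarType2String(kCommReduce);   // "CommReduce"
 */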
/*!
* \brief Construct a new Var expression
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
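/*
 * Example (illustrative sketch):
 *
 *   Var i = var("i");              // 32-bit signed index variable
 *   Var x = var("x", Float(32));   // float-typed variable
 */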
/*!
 * \brief Template function to convert a Map to an unordered_map.
 *  Sometimes useful for API gluing when internals use unordered_map.
* \param dmap The container map
* \return The corresponding unordered_map.
* \tparam K the key of the Map.
* \tparam V the value of the Map.
*/
template<typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
std::unordered_map<K, V> ret;
for (auto kv : dmap) {
ret[kv.first] = kv.second;
}
return ret;
}
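/*
 * Example (illustrative sketch, assuming a Map keyed by IterVar built elsewhere;
 * std::hash<::tvm::IterVar> is specialized at the end of this header):
 *
 *   Map<IterVar, Expr> extents = ...;
 *   std::unordered_map<IterVar, Expr> m = as_unordered_map(extents);
 */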
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::IterVar> {
std::size_t operator()(const ::tvm::IterVar& k) const {
return k.hash();
}
};
}  // namespace std
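/*
 * Example (illustrative sketch): the specialization above lets IterVar be used
 * as a key in standard unordered containers.
 *
 *   std::unordered_set<tvm::IterVar> visited;   // requires <unordered_set>
 *   visited.insert(iv);                         // iv: some tvm::IterVar
 */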
#endif // TVM_EXPR_H_