| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file tvm/expr.h |
| * \brief The Expr and related elements in DataFlow construction. |
| */ |
| #ifndef TVM_EXPR_H_ |
| #define TVM_EXPR_H_ |
| |
| #include <ir/Expr.h> |
| #include <ir/IRPrinter.h> |
| #include <string> |
| #include <algorithm> |
| #include "base.h" |
| #include "runtime/c_runtime_api.h" |
| |
| namespace tvm { |
| |
| using HalideIR::Type; |
| using HalideIR::Float; |
| using HalideIR::Bool; |
| using HalideIR::Int; |
| using HalideIR::UInt; |
| using HalideIR::Handle; |
| using HalideIR::ExprHash; |
| using HalideIR::ExprEqual; |
| |
| using HalideIR::Expr; |
| using HalideIR::VarExpr; |
| using HalideIR::IR::RangeNode; |
| using HalideIR::IR::FunctionRef; |
| using HalideIR::IR::FunctionBaseNode; |
| using HalideIR::Internal::IntImm; |
| using HalideIR::Internal::Stmt; |
| using HalideIR::Internal::IRPrinter; |
| using HalideIR::Internal::Variable; |
| |
| inline Type TVMShapeIndexType() { |
| if (std::is_signed<tvm_index_t>::value) { |
| return Int(sizeof(tvm_index_t) * 8); |
| } else { |
| return UInt(sizeof(tvm_index_t) * 8); |
| } |
| } |
| |
| inline Type TVMType2Type(TVMType t) { |
| return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes); |
| } |
| |
| inline TVMType Type2TVMType(Type t) { |
| TVMType ret; |
| ret.code = static_cast<uint8_t>(t.code()); |
| ret.bits = static_cast<uint8_t>(t.bits()); |
| ret.lanes = static_cast<uint16_t>(t.lanes()); |
| return ret; |
| } |
| |
| // Get number of bytes considering vector type. |
| inline int GetVectorBytes(Type dtype) { |
| int data_bits = dtype.bits() * dtype.lanes(); |
| // allow bool to exist |
| if (dtype == Bool()) return 1; |
| CHECK_EQ(data_bits % 8, 0U) |
| << "Need to load/store by multiple of bytes"; |
| return data_bits / 8; |
| } |
| |
| /*! \brief a named variable in TVM */ |
| class Var : public HalideIR::VarExpr { |
| public: |
| EXPORT explicit Var(const std::string& name_hint = "v", |
| Type t = Int(32)) : VarExpr(name_hint, t) {} |
| explicit Var(NodePtr<Node> n) : VarExpr(n) {} |
| explicit Var(VarExpr v) : VarExpr(v) {} |
| /*! |
| * \brief Make a new copy of var with same type, append suffix |
| * \param suffix The suffix to be appended. |
| * \return the new Var copy |
| */ |
| Var copy_with_suffix(const std::string& suffix) const { |
| return Var((*this)->name_hint + suffix, (*this)->type); |
| } |
| /*! \brief type indicate the container type */ |
| using ContainerType = Variable; |
| }; |
| |
| |
| /*! |
| * \brief Container of constant integer (IntImm). |
| * |
| * This is used to store and automate type check |
| * attributes that must be constant integer. |
| */ |
| class Integer : public Expr { |
| public: |
| Integer() : Expr() {} |
| /*! |
| * \brief constructor from node. |
| */ |
| explicit Integer(NodePtr<Node> node) : Expr(node) {} |
| /*! |
| * \brief Construct integer from int value. |
| */ |
| Integer(int value) : Expr(value) {} // NOLINT(*) |
| /*! |
| * \brief Assign an expression to integer. |
| * \param other another expression. |
| */ |
| Integer& operator=(const Integer& other) { |
| node_ = other.node_; |
| return *this; |
| } |
| /*! |
| * \brief Get pointer to the internal value. |
| * \return the content of the integer. |
| */ |
| const IntImm* operator->() const { |
| return static_cast<const IntImm*>(node_.get()); |
| } |
| /*! |
| * \brief convert to int64_t |
| */ |
| operator int64_t() const { |
| CHECK(node_ != nullptr) |
| << " Trying get reference a null Integer"; |
| return (*this)->value; |
| } |
| /*! \brief type indicate the container type */ |
| using ContainerType = IntImm; |
| }; |
| |
| |
| /*! \brief container class of iteration variable. */ |
| class IterVarNode; |
| |
| /*! |
| * \brief same as HalideIR::IR::Range |
| * except it provide an constructor with (begin, end) |
| * |
| * \note Traditional Halide's Range have a constructor with |
| * (begin, extent), which does not match the convention in e.g. python. |
| * We decided to correct it by removing the constructor in HalideIR, |
| * and add it back in TVM's range. |
| */ |
| class Range : public HalideIR::IR::Range { |
| public: |
| /*! \brief constructor */ |
| Range() {} |
| explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {} |
| /*! |
| * \brief constructor by begin and end |
| * \param begin The begin of the range. |
| * \param end The end of the range. |
| */ |
| TVM_DLL Range(Expr begin, Expr end); |
| |
| TVM_DLL static Range make_by_min_extent(Expr min, Expr extent); |
| }; |
| |
| using Region = Array<Range>; |
| |
/*!
 * \brief Type of iteration variable.
 *  Each IterVar has a specific type.
 *
 *  The type of an iter var can be overridden via
 *  stage.iter_var_attrs, given that the types are compatible.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   *  This normally corresponds to an axis of a Tensor.
   *  Allows all IterVar manipulations.
   *
   * \note This does not mean the loop
   *  has to be executed in a parallel fashion.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread index
   *  of a fixed thread launching group.
   *  Note that this is already assumed to be parallelized.
   *
   *  Disallow: split/fuse/vectorize/parallel
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *  Cannot be directly parallelized.
   *
   *  Disallow: parallel/vectorize
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loop with loop-carried dependency;
   *  the iterations must execute in order.
   *  Cannot be re-ordered.
   *
   *  Disallow: reorder/parallel/vectorize
   */
  kOrdered = 3,
  /*!
   * \brief The IterVar is opaque.
   *
   *  May not correspond to any generated loop.
   *  Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement a composite op
   *  or an external op.
   */
  kOpaque = 4,
  // The following are possible additional
  // types that are provided during schedule
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7,
  /*!
   * \brief Marks the boundary of a tensorization intrinsic.
   */
  kTensorized = 8
};
| |
| /*! |
| * \brief Iteration Variable, |
| * represents an iteration over an integer interval. |
| */ |
| class IterVar : public NodeRef { |
| public: |
| // construct a new iter var without a domain |
| IterVar() {} |
| // construct from shared ptr. |
| explicit IterVar(NodePtr<Node> n) : NodeRef(n) {} |
| /*! |
| * \brief access the internal node container |
| * \return the pointer to the internal node container |
| */ |
| inline const IterVarNode* operator->() const; |
| /*! |
| * \return the corresponding var in the IterVar. |
| */ |
| inline operator Expr() const; |
| /*! \brief specify container node */ |
| using ContainerType = IterVarNode; |
| }; |
| |
/*!
 * \brief Create a new IterVar that represents a thread axis.
 *
 * \param dom Optional, domain of the thread axis.
 * \param tag The thread tag of the axis.
 * \return The newly created IterVar.
 */
| TVM_DLL IterVar thread_axis(Range dom, std::string tag); |
| |
/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis.
 * \return The newly created IterVar.
 */
| TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); |
| |
| using Domain = Array<Range>; |
| |
| // print functions for expr |
| TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*) |
| |
| /*! |
| * \brief Dump the node to stderr, used for debug purposes. |
| * \param node The input node |
| */ |
| TVM_DLL void Dump(const NodeRef& node); |
| |
| // definition of Node. |
| /*! |
| * \brief An iteration variable representing an iteration |
| * over a one dimensional interval. |
| */ |
| class IterVarNode : public Node { |
| public: |
| /*! |
| * \brief the domain of iteration, if known, can be None |
| * For the intermediate schedule node, before schedule. |
| */ |
| Range dom; |
| /*! \brief The looping variable */ |
| Var var; |
| /*! \brief The type of the IterVar */ |
| IterVarType iter_type; |
| /*! |
| * \brief additional tag on the iteration variable, |
| * set this if this is binded already to a known thread tag. |
| */ |
| std::string thread_tag; |
| |
| void VisitAttrs(AttrVisitor* v) final { |
| v->Visit("dom", &dom); |
| v->Visit("var", &var); |
| v->Visit("iter_type", &iter_type); |
| v->Visit("thread_tag", &thread_tag); |
| } |
| |
| TVM_DLL static IterVar make(Range dom, Var var, |
| IterVarType iter_type, |
| std::string thread_tag = ""); |
| |
| static constexpr const char* _type_key = "IterVar"; |
| TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); |
| }; |
| |
| // inline implementations |
| inline const IterVarNode* IterVar::operator->() const { |
| return static_cast<const IterVarNode*>(node_.get()); |
| } |
| |
| inline IterVar::operator Expr() const { |
| return (*this)->var; |
| } |
| |
| inline const char* IterVarType2String(IterVarType t) { |
| switch (t) { |
| case kDataPar: return "DataPar"; |
| case kThreadIndex: return "ThreadIndex"; |
| case kCommReduce: return "CommReduce"; |
| case kOrdered: return "Ordered"; |
| case kOpaque: return "Opaque"; |
| case kUnrolled: return "Unrolled"; |
| case kVectorized: return "Vectorized"; |
| case kParallelized: return "Parallelized"; |
| case kTensorized: return "Tensorized"; |
| } |
| return "Unknown"; |
| } |
| |
/*!
 * \brief Construct a new Var expression.
 * \param name_hint The name hint for the expression.
 * \param t The type of the expression.
 * \return The newly constructed Var.
 */
| TVM_DLL Var var(const std::string& name_hint, Type t = Int(32)); |
| |
| /* |
| * \brief Template function to convert Map to unordered_map |
| * Sometimes useful for API gluing when internal uses unordered_map |
| * \param dmap The container map |
| * \return The corresponding unordered_map. |
| * \tparam K the key of the Map. |
| * \tparam V the value of the Map. |
| */ |
| template<typename K, typename V> |
| inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) { |
| std::unordered_map<K, V> ret; |
| for (auto kv : dmap) { |
| ret[kv.first] = kv.second; |
| } |
| return ret; |
| } |
| } // namespace tvm |
| |
| namespace std { |
| template <> |
| struct hash<::tvm::IterVar> { |
| std::size_t operator()(const ::tvm::IterVar& k) const { |
| return k.hash(); |
| } |
| }; |
| } |
| #endif // TVM_EXPR_H_ |