| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/tir/expr.h |
| * \brief TIR expressions. |
| */ |
| // Acknowledgement: Many low-level IR nodes originate from Halide. |
| #ifndef TVM_TIR_EXPR_H_ |
| #define TVM_TIR_EXPR_H_ |
| |
| #include <tvm/ir/expr.h> |
| #include <tvm/node/container.h> |
| #include <tvm/node/functor.h> |
| #include <tvm/node/node.h> |
| #include <tvm/runtime/c_runtime_api.h> |
| #include <tvm/runtime/data_type.h> |
| #include <tvm/tir/buffer.h> |
| #include <tvm/tir/var.h> |
| |
| #include <algorithm> |
| #include <iostream> |
| #include <limits> |
| #include <string> |
| #include <unordered_map> |
| #include <utility> |
| |
| namespace tvm { |
| namespace tir { |
| |
| using IntImmNode = tvm::IntImmNode; |
| using FloatImmNode = tvm::FloatImmNode; |
| |
| /*! \brief String constants, only used in asserts. */ |
| class StringImmNode : public PrimExprNode { |
| public: |
| /*! \brief The constant value content. */ |
| String value; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("value", &value); |
| } |
| |
| bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { |
| return equal(value, other->value); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } |
| |
| static constexpr const char* _type_key = "tir.StringImm"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to StringImmNode. |
| * \sa StringImmNode |
| */ |
| class StringImm : public PrimExpr { |
| public: |
| TVM_DLL StringImm(String value); |
| TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); |
| }; |
| |
| /*! |
| * \brief Cast value from one data type to another. |
| * \note The lanes of value should keep fixed. |
| */ |
| class CastNode : public PrimExprNode { |
| public: |
| /*! \brief Original data type. */ |
| PrimExpr value; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("value", &value); |
| } |
| |
| bool SEqualReduce(const CastNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(value, other->value); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(value); |
| } |
| |
| static constexpr const char* _type_key = "tir.Cast"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to CastNode |
| * \sa CastNode |
| */ |
| class Cast : public PrimExpr { |
| public: |
| TVM_DLL Cast(DataType dtype, PrimExpr value); |
| TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); |
| }; |
| |
| /*! |
| * \brief Base template to implement binary ops. |
| * \tparam T The type of the child class. |
| */ |
| template <typename T> |
| class BinaryOpNode : public PrimExprNode { |
| public: |
| /*! \brief The left operand. */ |
| PrimExpr a; |
| /*! \brief The right operand. */ |
| PrimExpr b; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &(this->dtype)); |
| v->Visit("a", &a); |
| v->Visit("b", &b); |
| } |
| |
| bool SEqualReduce(const T* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(a); |
| hash_reduce(b); |
| } |
| |
| TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); |
| }; |
| |
| /*! \brief a + b */ |
| class AddNode : public BinaryOpNode<AddNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Add"; |
| }; |
| |
| /*! |
| * \brief Managed reference to AddNode |
| * \sa AddNode |
| */ |
| class Add : public PrimExpr { |
| public: |
| TVM_DLL Add(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); |
| }; |
| |
| /*! \brief a - b */ |
| class SubNode : public BinaryOpNode<SubNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Sub"; |
| }; |
| |
| /*! |
| * \brief Managed reference to SubNode |
| * \sa SubNode |
| */ |
| class Sub : public PrimExpr { |
| public: |
| TVM_DLL Sub(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); |
| }; |
| |
| /*! \brief a * b */ |
| class MulNode : public BinaryOpNode<MulNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Mul"; |
| }; |
| |
| /*! |
| * \brief Managed reference to MulNode |
| * \sa MulNode |
| */ |
| class Mul : public PrimExpr { |
| public: |
| TVM_DLL Mul(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); |
| }; |
| |
| /*! |
| * \brief a / b in the C semnatics. |
| * \note For integer division, C standard uses trunc div. |
| */ |
| class DivNode : public BinaryOpNode<DivNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Div"; |
| }; |
| |
| /*! |
| * \brief Managed reference to DivNode |
| * \sa DivNode |
| */ |
| class Div : public PrimExpr { |
| public: |
| TVM_DLL Div(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); |
| }; |
| |
| /*! |
| * \brief a % b in the C semnatics. |
| * \note For integer division, C standard uses trunc div. |
| */ |
| class ModNode : public BinaryOpNode<ModNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Mod"; |
| }; |
| |
| /*! |
| * \brief Managed reference to ModNode |
| * \sa ModNode |
| */ |
| class Mod : public PrimExpr { |
| public: |
| TVM_DLL Mod(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); |
| }; |
| |
| /*! \brief Floor division, floor(a/b) */ |
| class FloorDivNode : public BinaryOpNode<FloorDivNode> { |
| public: |
| static constexpr const char* _type_key = "tir.FloorDiv"; |
| }; |
| |
| /*! |
| * \brief Managed reference to FloorDivNode |
| * \sa FloorDivNode |
| */ |
| class FloorDiv : public PrimExpr { |
| public: |
| TVM_DLL FloorDiv(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); |
| }; |
| |
| /*! \brief The remainder of the floordiv */ |
| class FloorModNode : public BinaryOpNode<FloorModNode> { |
| public: |
| static constexpr const char* _type_key = "tir.FloorMod"; |
| }; |
| |
| /*! |
| * \brief Managed reference to FloorModNode |
| * \sa FloorModNode |
| */ |
| class FloorMod : public PrimExpr { |
| public: |
| TVM_DLL FloorMod(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); |
| }; |
| |
| /*! \brief min(a, b) */ |
| class MinNode : public BinaryOpNode<MinNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Min"; |
| }; |
| |
| /*! |
| * \brief Managed reference to MinNode |
| * \sa MinNode |
| */ |
| class Min : public PrimExpr { |
| public: |
| TVM_DLL Min(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); |
| }; |
| |
| /*! \brief max(a, b) */ |
| class MaxNode : public BinaryOpNode<MaxNode> { |
| public: |
| static constexpr const char* _type_key = "tir.Max"; |
| }; |
| |
| /*! |
| * \brief Managed reference to MaxNode |
| * \sa MaxNode |
| */ |
| class Max : public PrimExpr { |
| public: |
| TVM_DLL Max(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); |
| }; |
| |
| /*! |
| * \brief Base template to implement comparison ops. |
| * \tparam T The type of the child class. |
| */ |
| template <typename T> |
| class CmpOpNode : public PrimExprNode { |
| public: |
| /*! \brief The left operand. */ |
| PrimExpr a; |
| /*! \brief The right operand. */ |
| PrimExpr b; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &(this->dtype)); |
| v->Visit("a", &a); |
| v->Visit("b", &b); |
| } |
| |
| bool SEqualReduce(const T* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(a); |
| hash_reduce(b); |
| } |
| |
| TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); |
| }; |
| |
| /*! \brief a == b */ |
| class EQNode : public CmpOpNode<EQNode> { |
| public: |
| static constexpr const char* _type_key = "tir.EQ"; |
| }; |
| |
| /*! |
| * \brief Managed reference to EQNode |
| * \sa EQNode |
| */ |
| class EQ : public PrimExpr { |
| public: |
| TVM_DLL EQ(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); |
| }; |
| |
| /*! \brief a != b */ |
| class NENode : public CmpOpNode<NENode> { |
| public: |
| static constexpr const char* _type_key = "tir.NE"; |
| }; |
| |
| /*! |
| * \brief Managed reference to NENode |
| * \sa NENode |
| */ |
| class NE : public PrimExpr { |
| public: |
| TVM_DLL NE(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); |
| }; |
| |
| /*! \brief a < b */ |
| class LTNode : public CmpOpNode<LTNode> { |
| public: |
| static constexpr const char* _type_key = "tir.LT"; |
| }; |
| |
| /*! |
| * \brief Managed reference to LTNode |
| * \sa LTNode |
| */ |
| class LT : public PrimExpr { |
| public: |
| TVM_DLL LT(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); |
| }; |
| |
| /*! \brief a <= b */ |
| struct LENode : public CmpOpNode<LENode> { |
| public: |
| static constexpr const char* _type_key = "tir.LE"; |
| }; |
| |
| /*! |
| * \brief Managed reference to LENode |
| * \sa LENode |
| */ |
| class LE : public PrimExpr { |
| public: |
| TVM_DLL LE(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); |
| }; |
| |
| /*! \brief a > b */ |
| class GTNode : public CmpOpNode<GTNode> { |
| public: |
| static constexpr const char* _type_key = "tir.GT"; |
| }; |
| |
| /*! |
| * \brief Managed reference to GTNode |
| * \sa GTNode |
| */ |
| class GT : public PrimExpr { |
| public: |
| TVM_DLL GT(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); |
| }; |
| |
| /*! \brief a >= b */ |
| class GENode : public CmpOpNode<GENode> { |
| public: |
| static constexpr const char* _type_key = "tir.GE"; |
| }; |
| |
| /*! |
| * \brief Managed reference to GENode |
| * \sa GENode |
| */ |
| class GE : public PrimExpr { |
| public: |
| TVM_DLL GE(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); |
| }; |
| |
| /*! \brief a && b */ |
| class AndNode : public PrimExprNode { |
| public: |
| /*! \brief The left operand. */ |
| PrimExpr a; |
| /*! \brief The right operand. */ |
| PrimExpr b; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &(this->dtype)); |
| v->Visit("a", &a); |
| v->Visit("b", &b); |
| } |
| |
| bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(a); |
| hash_reduce(b); |
| } |
| |
| static constexpr const char* _type_key = "tir.And"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to AndNode |
| * \sa AndNode |
| */ |
| class And : public PrimExpr { |
| public: |
| TVM_DLL And(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); |
| }; |
| |
| /*! \brief a || b */ |
| class OrNode : public PrimExprNode { |
| public: |
| /*! \brief The left operand. */ |
| PrimExpr a; |
| /*! \brief The right operand. */ |
| PrimExpr b; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("a", &a); |
| v->Visit("b", &b); |
| } |
| |
| bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(a); |
| hash_reduce(b); |
| } |
| |
| static constexpr const char* _type_key = "tir.Or"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to OrNode |
| * \sa OrNode |
| */ |
| class Or : public PrimExpr { |
| public: |
| TVM_DLL Or(PrimExpr a, PrimExpr b); |
| TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); |
| }; |
| |
| /*! \brief !a */ |
| class NotNode : public PrimExprNode { |
| public: |
| /*! \brief The input operand. */ |
| PrimExpr a; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("a", &a); |
| } |
| |
| bool SEqualReduce(const NotNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(a, other->a); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(a); |
| } |
| |
| static constexpr const char* _type_key = "tir.Not"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to NotNode |
| * \sa NotNode |
| */ |
| class Not : public PrimExpr { |
| public: |
| TVM_DLL Not(PrimExpr a); |
| TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); |
| }; |
| |
| /*! |
| * \brief return true_value if condition is true, otherwise return false_value. |
| * \note Both true_value and false_value could be evaluated |
| * regardless of the condition value. |
| * Do not use it to guard against out of bound access, |
| * please use if_then_else instead. |
| */ |
| class SelectNode : public PrimExprNode { |
| public: |
| /*! \brief The condition */ |
| PrimExpr condition; |
| /*! \brief value to be returned when condition is true. */ |
| PrimExpr true_value; |
| /*! \brief value to be returned when condition is false. */ |
| PrimExpr false_value; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("condition", &condition); |
| v->Visit("true_value", &true_value); |
| v->Visit("false_value", &false_value); |
| } |
| |
| bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(condition, other->condition) && |
| equal(true_value, other->true_value) && equal(false_value, other->false_value); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(condition); |
| hash_reduce(true_value); |
| hash_reduce(false_value); |
| } |
| |
| static constexpr const char* _type_key = "tir.Select"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to SelectNode |
| * \sa SelectNode |
| */ |
| class Select : public PrimExpr { |
| public: |
| TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); |
| }; |
| |
| /*! |
| * \brief Load value from the high dimension buffer. |
| * |
| * \code |
| * |
| * value = buffer[i, j]; |
| * |
| * \endcode |
| * \sa BufferStore |
| */ |
| class BufferLoadNode : public PrimExprNode { |
| public: |
| /*! \brief The buffer variable. */ |
| Buffer buffer; |
| /*! \brief The indices location to be loaded. */ |
| Array<PrimExpr> indices; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &(this->dtype)); |
| v->Visit("buffer", &buffer); |
| v->Visit("indices", &indices); |
| } |
| |
| bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(buffer, other->buffer) && |
| equal(indices, other->indices); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(buffer); |
| hash_reduce(indices); |
| } |
| |
| static constexpr const char* _type_key = "tir.BufferLoad"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to BufferLoadNode. |
| * \sa BufferLoadNode |
| */ |
| class BufferLoad : public PrimExpr { |
| public: |
| TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices); |
| TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); |
| }; |
| |
| /*! |
| * \brief Load value from the result produced by the producer. |
| * |
| * \note This node only appears in high-level DSLs that are built on top of the TIR. |
| * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower |
| * this node before TIR transformations. |
| * |
| * \sa ProducerLoad, DataProducerNode |
| */ |
| class ProducerLoadNode : public PrimExprNode { |
| public: |
| /*! \brief The buffer producer. */ |
| DataProducer producer; |
| /*! \brief The location arguments. */ |
| Array<PrimExpr> indices; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &(this->dtype)); |
| v->Visit("producer", &producer); |
| v->Visit("indices", &indices); |
| } |
| |
| bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(producer, other->producer) && |
| equal(indices, other->indices); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(producer); |
| hash_reduce(indices); |
| } |
| |
| static constexpr const char* _type_key = "tir.ProducerLoad"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to ProducerLoadNode. |
| * \sa ProducerLoadNode |
| */ |
| class ProducerLoad : public PrimExpr { |
| public: |
| TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); |
| }; |
| |
| /*! |
| * \brief Load the value from buffer_var. |
| * |
| * Equivalent to ((DType*)buffer_var)[index] |
| * where DType is the type specified by type().element_of(). |
| * |
| * For example, if type = float32x3, then the load will corresponds to |
| * |
| * \code |
| * |
| * auto buffer = static_cast<float*>(buffer_var); |
| * auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]); |
| * |
| * \endcode |
| */ |
| class LoadNode : public PrimExprNode { |
| public: |
| /*! \brief The buffer variable. */ |
| Var buffer_var; |
| /*! \brief The index locations to be loaded. */ |
| PrimExpr index; |
| /*! \brief The predicate to mask which lanes would be loaded. */ |
| PrimExpr predicate; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("buffer_var", &buffer_var); |
| v->Visit("index", &index); |
| v->Visit("predicate", &predicate); |
| } |
| |
| bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) && |
| equal(index, other->index) && equal(predicate, other->predicate); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(buffer_var); |
| hash_reduce(index); |
| hash_reduce(predicate); |
| } |
| |
| static constexpr const char* _type_key = "tir.Load"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to LoadNode |
| * \sa LoadNode |
| */ |
| class Load : public PrimExpr { |
| public: |
| TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); |
| TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode); |
| }; |
| |
| /*! |
| * \brief Construct a vector with lanes elements |
| * where its i-th element equals base + i * stride. |
| * This is useful to construct a index for a continuous vector load. |
| * |
| * Examples: |
| * - ramp(0, 1, 3) = [0, 1, 2] |
| * - ramp(1, 2, 4) = [1, 3, 5, 7] |
| */ |
| class RampNode : public PrimExprNode { |
| public: |
| /*! \brief The base value. */ |
| PrimExpr base; |
| /*! \brief The stride of each step. */ |
| PrimExpr stride; |
| /*! \brief Total number of lanes. */ |
| int lanes; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("base", &base); |
| v->Visit("stride", &stride); |
| v->Visit("lanes", &lanes); |
| } |
| |
| bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && |
| equal(lanes, other->lanes); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(base); |
| hash_reduce(stride); |
| hash_reduce(lanes); |
| } |
| |
| static constexpr const char* _type_key = "tir.Ramp"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to RampNode |
| * \sa RampNode |
| */ |
| class Ramp : public PrimExpr { |
| public: |
| TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes); |
| TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); |
| }; |
| |
| /*! \brief Create a vector where all the elements are value. */ |
| class BroadcastNode : public PrimExprNode { |
| public: |
| /*! \brief The base value. */ |
| PrimExpr value; |
| /*! \brief The number of lanes. */ |
| int lanes; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("value", &value); |
| v->Visit("lanes", &lanes); |
| } |
| |
| bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(value); |
| hash_reduce(lanes); |
| } |
| |
| static constexpr const char* _type_key = "tir.Broadcast"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to BroadcastNode |
| * \sa BroadcastNode |
| */ |
| class Broadcast : public PrimExpr { |
| public: |
| TVM_DLL Broadcast(PrimExpr value, int lanes); |
| TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); |
| }; |
| |
| /*! |
| * \brief Let binding. Bind var to value then evaluate body. |
| */ |
| class LetNode : public PrimExprNode { |
| public: |
| /*! \brief The variable. */ |
| Var var; |
| /*! \brief The value to be binded. */ |
| PrimExpr value; |
| /*! \brief The result expression. */ |
| PrimExpr body; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("var", &var); |
| v->Visit("value", &value); |
| v->Visit("body", &body); |
| } |
| |
| bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) && |
| equal(value, other->value) && equal(body, other->body); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce.DefHash(var); |
| hash_reduce(value); |
| hash_reduce(body); |
| } |
| |
| static constexpr const char* _type_key = "tir.Let"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to LetNode |
| * \sa LetNode |
| */ |
| class Let : public PrimExpr { |
| public: |
| TVM_DLL Let(Var var, PrimExpr value, PrimExpr body); |
| TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); |
| }; |
| |
| /*! |
| * \brief Call node. |
| */ |
| class CallNode : public PrimExprNode { |
| public: |
| /*! |
| * \brief The operator(function) being invoked |
| * |
| * - It can be tvm::Op which corresponds to the primitive operators(intrinsics). |
| * - It can also be another function in the IRModule (GlobalVar). |
| */ |
| RelayExpr op; |
| |
| /*! \brief The arguments. */ |
| Array<PrimExpr> args; |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("op", &op); |
| v->Visit("args", &args); |
| } |
| |
| bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(op); |
| hash_reduce(args); |
| } |
| |
| static constexpr const char* _type_key = "tir.Call"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to CallNode |
| * \sa CallNode |
| */ |
| class Call : public PrimExpr { |
| public: |
| TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args); |
| TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); |
| }; |
| |
| /*! |
| * \brief Shuffle instruction. |
| * vec = concat(vectors) |
| * result = (vec[indices[0]], vec[indices[1]] ...) |
| */ |
| class ShuffleNode : public PrimExprNode { |
| public: |
| /*! \brief the input vectors. */ |
| Array<PrimExpr> vectors; |
| /*! \brief The indices of each element. */ |
| Array<PrimExpr> indices; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("vectors", &vectors); |
| v->Visit("indices", &indices); |
| } |
| |
| bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype) && equal(vectors, other->vectors) && |
| equal(indices, other->indices); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(vectors); |
| hash_reduce(indices); |
| } |
| |
| static constexpr const char* _type_key = "tir.Shuffle"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to ShuffleNode |
| * \sa ShuffleNode |
| */ |
| class Shuffle : public PrimExpr { |
| public: |
| TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices); |
| TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors); |
| TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); |
| }; |
| |
| // Reduce operator |
| /*! |
| * \brief A commutative reducer node to represent a commutative |
| * binary operator with identity element |
| */ |
| class CommReducerNode : public Object { |
| public: |
| /*! \brief The left argument of reducer */ |
| Array<Var> lhs; |
| /*! \brief The right argument of reducer */ |
| Array<Var> rhs; |
| /*! \brief The result of reducer */ |
| Array<PrimExpr> result; |
| /*! |
| * \brief The identity element of reducer, which leaves other |
| * elements unchanged when combined with it, with respect to |
| * the binary operation of this reducer uses. |
| */ |
| Array<PrimExpr> identity_element; |
| /*! \brief Function call operator to combine a and b */ |
| Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("lhs", &lhs); |
| v->Visit("rhs", &rhs); |
| v->Visit("result", &result); |
| v->Visit("identity_element", &identity_element); |
| } |
| |
| bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { |
| return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) && |
| equal(result, other->result) && equal(identity_element, other->identity_element); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce.DefHash(lhs); |
| hash_reduce.DefHash(rhs); |
| hash_reduce(result); |
| hash_reduce(identity_element); |
| } |
| |
| static constexpr const char* _type_key = "tir.CommReducer"; |
| static constexpr const bool _type_has_method_sequal_reduce = true; |
| static constexpr const bool _type_has_method_shash_reduce = true; |
| TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); |
| }; |
| |
| /*! |
| * \brief Managed reference to CommReducerNode |
| * \sa CommReducerNode |
| */ |
| class CommReducer : public ObjectRef { |
| public: |
| TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result, |
| Array<PrimExpr> identity_element); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode); |
| }; |
| |
| /*! \brief Reduction operator operator */ |
| class ReduceNode : public PrimExprNode { |
| public: |
| /*! \brief The commutative combiner */ |
| CommReducer combiner; |
| /*! \brief The source operand */ |
| Array<PrimExpr> source; |
| /*! \brief The init operand */ |
| Array<PrimExpr> init; |
| /*! \brief The reduction axis */ |
| Array<IterVar> axis; |
| /*! |
| * \brief Predicate on the reduction |
| * Only add the body to reduction if condition is true. |
| */ |
| PrimExpr condition; |
| /*! \brief the index of this reduce node */ |
| int value_index; |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("dtype", &dtype); |
| v->Visit("combiner", &combiner); |
| v->Visit("source", &source); |
| v->Visit("init", &init); |
| v->Visit("axis", &axis); |
| v->Visit("condition", &condition); |
| v->Visit("value_index", &value_index); |
| } |
| |
| bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { |
| // check axis first so IterVars can define the necessary variables. |
| return equal(dtype, other->dtype) && equal(axis, other->axis) && |
| equal(combiner, other->combiner) && equal(source, other->source) && |
| equal(init, other->init) && equal(condition, other->condition) && |
| equal(value_index, other->value_index); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const { |
| hash_reduce(dtype); |
| hash_reduce(axis); |
| hash_reduce(combiner); |
| hash_reduce(source); |
| hash_reduce(init); |
| hash_reduce(condition); |
| hash_reduce(value_index); |
| } |
| |
| static constexpr const char* _type_key = "tir.Reduce"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to ReduceNode |
| * \sa ReduceNode |
| */ |
| class Reduce : public PrimExpr { |
| public: |
| TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition, |
| int value_index, Array<PrimExpr> init); |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); |
| }; |
| |
| /*! \brief Any shape. */ |
| class AnyNode : public PrimExprNode { |
| public: |
| void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); } |
| |
| bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { |
| return equal(dtype, other->dtype); |
| } |
| |
| void SHashReduce(SHashReducer hash_reduce) const {} |
| |
| /*! \brief Convert to var. */ |
| Var ToVar() const { return Var("any_dim", DataType::Int(32)); } |
| |
| static constexpr const char* _type_key = "tir.Any"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); |
| }; |
| |
| /*! |
| * \brief Managed reference to AnyNode |
| * \sa AnyNode |
| */ |
| class Any : public PrimExpr { |
| public: |
| TVM_DLL Any(); |
| |
| TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode); |
| }; |
| |
| /* |
| * \brief Template function to convert Map to unordered_map |
| * Sometimes useful for API gluing when internal uses unordered_map |
| * \param dmap The container map |
| * \return The corresponding unordered_map. |
| * \tparam K the key of the Map. |
| * \tparam V the value of the Map. |
| */ |
| template <typename K, typename V> |
| inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) { |
| std::unordered_map<K, V> ret; |
| for (auto kv : dmap) { |
| ret[kv.first] = kv.second; |
| } |
| return ret; |
| } |
| } // namespace tir |
| } // namespace tvm |
| |
| namespace std { |
| template <> |
| struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; |
| } // namespace std |
| #endif // TVM_TIR_EXPR_H_ |