blob: cd51108b0d233e749484acb5e442c21804b1ffbb [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tirx/expr.h
* \brief TIR expressions.
*/
// Acknowledgement: Many low-level IR nodes originate from Halide.
#ifndef TVM_TIR_EXPR_H_
#define TVM_TIR_EXPR_H_
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/node_functor.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tirx/buffer.h>
#include <tvm/tirx/var.h>
#include <algorithm>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
namespace tvm {
namespace tirx {
// Re-export the generic IR immediate node types into tirx so they can be
// referred to as tirx::IntImmNode / tirx::FloatImmNode.
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
/*! \brief ffi::String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
ffi::String value;
/*! \brief Register the reflected fields of this node (only `value`). */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<StringImmNode>().def_ro("value", &StringImmNode::value);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, PrimExprNode);
};
/*!
* \brief Managed reference to StringImmNode.
* \sa StringImmNode
*/
class StringImm : public PrimExpr {
public:
/*!
* \brief Construct a string immediate.
* \param value The string content.
* \param span The source span of this expression.
*/
TVM_DLL StringImm(ffi::String value, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StringImm, PrimExpr, StringImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
};
/*!
* \brief Cast value from one data type to another.
* \note The number of lanes of value should stay unchanged by the cast.
*/
class CastNode : public PrimExprNode {
public:
/*!
* \brief The value to be cast.
*
* The target data type is the dtype passed to the Cast constructor.
*/
PrimExpr value;
/*! \brief Register the reflected fields of this node. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CastNode>().def_ro("value", &CastNode::value);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, PrimExprNode);
};
/*!
* \brief Managed reference to CastNode
* \sa CastNode
*/
class Cast : public PrimExpr {
public:
/*!
* \brief Construct a cast expression.
* \param dtype The target data type.
* \param value The value to be cast.
* \param span The source span of this expression.
*/
TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode);
};
/*!
* \brief Base template to implement binary ops.
* \tparam T The type of the child class (CRTP); it must declare a static
*  `_type_key` string.
*/
template <typename T>
class BinaryOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
PrimExpr a;
/*! \brief The right operand. */
PrimExpr b;
/*!
* \brief Register fields `a` and `b` for the concrete node type T.
* \note Registration order (a, then b) is kept as-is; reflection consumers
*  may depend on field order.
*/
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
}
// Each concrete binary op is a final type with no reserved child type slots.
static const constexpr int _type_child_slots [[maybe_unused]] = 0;
static const constexpr bool _type_final [[maybe_unused]] = true;
// Type info for T; the type key is the one predefined by T itself.
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode);
};
/*! \brief a + b */
class AddNode : public BinaryOpNode<AddNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Add";
};
/*!
* \brief Managed reference to AddNode
* \sa AddNode
*/
class Add : public PrimExpr {
public:
/*!
* \brief Construct a + b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Add, PrimExpr, AddNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode);
};
/*! \brief a - b */
class SubNode : public BinaryOpNode<SubNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Sub";
};
/*!
* \brief Managed reference to SubNode
* \sa SubNode
*/
class Sub : public PrimExpr {
public:
/*!
* \brief Construct a - b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sub, PrimExpr, SubNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode);
};
/*! \brief a * b */
class MulNode : public BinaryOpNode<MulNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Mul";
};
/*!
* \brief Managed reference to MulNode
* \sa MulNode
*/
class Mul : public PrimExpr {
public:
/*!
* \brief Construct a * b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mul, PrimExpr, MulNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode);
};
/*!
* \brief a / b in the C semantics.
* \note For integer division, the C standard uses truncating division
*  (rounding toward zero). See FloorDiv for the flooring variant.
*/
class DivNode : public BinaryOpNode<DivNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Div";
};
/*!
* \brief Managed reference to DivNode
* \sa DivNode
*/
class Div : public PrimExpr {
public:
/*!
* \brief Construct a / b (C semantics).
* \param a The dividend.
* \param b The divisor.
* \param span The source span of this expression.
*/
TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Div, PrimExpr, DivNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode);
};
/*!
* \brief a % b in the C semantics.
* \note For integers, the C standard defines % as the remainder of truncating
*  division. See FloorMod for the flooring variant.
*/
class ModNode : public BinaryOpNode<ModNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Mod";
};
/*!
* \brief Managed reference to ModNode
* \sa ModNode
*/
class Mod : public PrimExpr {
public:
/*!
* \brief Construct a % b (C semantics).
* \param a The dividend.
* \param b The divisor.
* \param span The source span of this expression.
*/
TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Mod, PrimExpr, ModNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode);
};
/*! \brief Floor division, floor(a/b) */
class FloorDivNode : public BinaryOpNode<FloorDivNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.FloorDiv";
};
/*!
* \brief Managed reference to FloorDivNode
* \sa FloorDivNode
*/
class FloorDiv : public PrimExpr {
public:
/*!
* \brief Construct floor(a / b).
* \param a The dividend.
* \param b The divisor.
* \param span The source span of this expression.
*/
TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorDiv, PrimExpr, FloorDivNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode);
};
/*! \brief The remainder of the floordiv */
class FloorModNode : public BinaryOpNode<FloorModNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.FloorMod";
};
/*!
* \brief Managed reference to FloorModNode
* \sa FloorModNode
*/
class FloorMod : public PrimExpr {
public:
/*!
* \brief Construct the remainder of floordiv(a, b).
* \param a The dividend.
* \param b The divisor.
* \param span The source span of this expression.
*/
TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloorMod, PrimExpr, FloorModNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode);
};
/*! \brief min(a, b) */
class MinNode : public BinaryOpNode<MinNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Min";
};
/*!
* \brief Managed reference to MinNode
* \sa MinNode
*/
class Min : public PrimExpr {
public:
/*!
* \brief Construct min(a, b).
* \param a The first operand.
* \param b The second operand.
* \param span The source span of this expression.
*/
TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Min, PrimExpr, MinNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode);
};
/*! \brief max(a, b) */
class MaxNode : public BinaryOpNode<MaxNode> {
public:
// Type key picked up by BinaryOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.Max";
};
/*!
* \brief Managed reference to MaxNode
* \sa MaxNode
*/
class Max : public PrimExpr {
public:
/*!
* \brief Construct max(a, b).
* \param a The first operand.
* \param b The second operand.
* \param span The source span of this expression.
*/
TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Max, PrimExpr, MaxNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode);
};
/*!
* \brief Base template to implement comparison ops.
* \tparam T The type of the child class (CRTP); it must declare a static
*  `_type_key` string.
*/
template <typename T>
class CmpOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
PrimExpr a;
/*! \brief The right operand. */
PrimExpr b;
/*!
* \brief Register fields `a` and `b` for the concrete node type T.
* \note Registration order (a, then b) is kept as-is; reflection consumers
*  may depend on field order.
*/
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<T>().def_ro("a", &T::a).def_ro("b", &T::b);
}
// Each concrete comparison op is a final type with no reserved child type slots.
static const constexpr int _type_child_slots [[maybe_unused]] = 0;
static const constexpr bool _type_final [[maybe_unused]] = true;
// Type info for T; the type key is the one predefined by T itself.
TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(T, PrimExprNode);
};
/*! \brief a == b */
class EQNode : public CmpOpNode<EQNode> {
public:
// Type key picked up by CmpOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.EQ";
};
/*!
* \brief Managed reference to EQNode
* \sa EQNode
*/
class EQ : public PrimExpr {
public:
/*!
* \brief Construct a == b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EQ, PrimExpr, EQNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode);
};
/*! \brief a != b */
class NENode : public CmpOpNode<NENode> {
public:
// Type key picked up by CmpOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.NE";
};
/*!
* \brief Managed reference to NENode
* \sa NENode
*/
class NE : public PrimExpr {
public:
/*!
* \brief Construct a != b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NE, PrimExpr, NENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode);
};
/*! \brief a < b */
class LTNode : public CmpOpNode<LTNode> {
public:
// Type key picked up by CmpOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.LT";
};
/*!
* \brief Managed reference to LTNode
* \sa LTNode
*/
class LT : public PrimExpr {
public:
/*!
* \brief Construct a < b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LT, PrimExpr, LTNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode);
};
/*! \brief a <= b */
// Declared with `class` like every sibling node. A struct/class class-key
// mismatch with forward declarations triggers MSVC C4099 and changes MSVC
// symbol mangling. Semantics are unchanged because the base and members are
// explicitly public.
class LENode : public CmpOpNode<LENode> {
 public:
  // Type key picked up by CmpOpNode's predefined-type-key declaration.
  static constexpr const char* _type_key = "tirx.LE";
};
/*!
 * \brief Managed reference to LENode
 * \sa LENode
 */
class LE : public PrimExpr {
 public:
  /*!
   * \brief Construct a <= b.
   * \param a The left operand.
   * \param b The right operand.
   * \param span The source span of this expression.
   */
  TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LE, PrimExpr, LENode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode);
};
/*! \brief a > b */
class GTNode : public CmpOpNode<GTNode> {
public:
// Type key picked up by CmpOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.GT";
};
/*!
* \brief Managed reference to GTNode
* \sa GTNode
*/
class GT : public PrimExpr {
public:
/*!
* \brief Construct a > b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GT, PrimExpr, GTNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode);
};
/*! \brief a >= b */
class GENode : public CmpOpNode<GENode> {
public:
// Type key picked up by CmpOpNode's predefined-type-key declaration.
static constexpr const char* _type_key = "tirx.GE";
};
/*!
* \brief Managed reference to GENode
* \sa GENode
*/
class GE : public PrimExpr {
public:
/*!
* \brief Construct a >= b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GE, PrimExpr, GENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode);
};
/*! \brief a && b */
class AndNode : public PrimExprNode {
public:
/*! \brief The left operand. */
PrimExpr a;
/*! \brief The right operand. */
PrimExpr b;
/*! \brief Register fields `a` and `b` (in that order). */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AndNode>().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, PrimExprNode);
};
/*!
* \brief Managed reference to AndNode
* \sa AndNode
*/
class And : public PrimExpr {
public:
/*!
* \brief Construct a && b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(And, PrimExpr, AndNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode);
};
/*! \brief a || b */
class OrNode : public PrimExprNode {
public:
/*! \brief The left operand. */
PrimExpr a;
/*! \brief The right operand. */
PrimExpr b;
/*! \brief Register fields `a` and `b` (in that order). */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<OrNode>().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, PrimExprNode);
};
/*!
* \brief Managed reference to OrNode
* \sa OrNode
*/
class Or : public PrimExpr {
public:
/*!
* \brief Construct a || b.
* \param a The left operand.
* \param b The right operand.
* \param span The source span of this expression.
*/
TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Or, PrimExpr, OrNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode);
};
/*! \brief !a */
class NotNode : public PrimExprNode {
public:
/*! \brief The input operand. */
PrimExpr a;
/*! \brief Register the single field `a`. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<NotNode>().def_ro("a", &NotNode::a);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, PrimExprNode);
};
/*!
* \brief Managed reference to NotNode
* \sa NotNode
*/
class Not : public PrimExpr {
public:
/*!
* \brief Construct !a.
* \param a The input operand.
* \param span The source span of this expression.
*/
TVM_DLL Not(PrimExpr a, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Not, PrimExpr, NotNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode);
};
/*!
* \brief return true_value if condition is true, otherwise return false_value.
* \note Both true_value and false_value could be evaluated
*  regardless of the condition value.
*  Do not use it to guard against out of bound access,
*  please use if_then_else instead.
*/
class SelectNode : public PrimExprNode {
public:
/*! \brief The condition */
PrimExpr condition;
/*! \brief value to be returned when condition is true. */
PrimExpr true_value;
/*! \brief value to be returned when condition is false. */
PrimExpr false_value;
/*! \brief Register fields in declaration order: condition, true_value, false_value. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SelectNode>()
.def_ro("condition", &SelectNode::condition)
.def_ro("true_value", &SelectNode::true_value)
.def_ro("false_value", &SelectNode::false_value);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, PrimExprNode);
};
/*!
* \brief Managed reference to SelectNode
* \sa SelectNode
*/
class Select : public PrimExpr {
public:
/*!
* \brief Construct a select expression.
* \param condition The condition.
* \param true_value The value returned when condition is true.
* \param false_value The value returned when condition is false.
* \param span The source span of this expression.
*/
TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Select, PrimExpr, SelectNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode);
};
/*!
* \brief Load value from the high dimension buffer.
*
* \code
*
*  value = buffer[i, j];
*
* \endcode
* \sa BufferStore
*/
class BufferLoadNode : public PrimExprNode {
public:
/*! \brief The buffer to load from. */
Buffer buffer;
/*! \brief The indices location to be loaded. */
ffi::Array<PrimExpr> indices;
/*! \brief The predicate mask for loading values; may be absent. */
ffi::Optional<PrimExpr> predicate;
/*! \brief Register fields in declaration order: buffer, indices, predicate. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BufferLoadNode>()
.def_ro("buffer", &BufferLoadNode::buffer)
.def_ro("indices", &BufferLoadNode::indices)
.def_ro("predicate", &BufferLoadNode::predicate);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, PrimExprNode);
private:
/*! \brief Set the dtype based on the buffer/indices
*
* Usually, the BufferLoad's dtype will be the same dtype as the
* buffer.  This may have a different number of lanes than the
* buffer's dtype if index values have more than 1 lane.
*
* This function should only be called during construction and after
* CopyOnWrite.  Friend class used here to restrict usage.
*/
void LegalizeDType();
// Only these classes may call LegalizeDType().
friend class BufferLoad;
friend class CustomDatatypesLowerer;
friend class VectorTypeRewriter;
friend class Vectorizer;
};
/*!
* \brief Managed reference to BufferLoadNode.
* \sa BufferLoadNode
*/
class BufferLoad : public PrimExpr {
public:
/*!
* \brief Construct a buffer load.
* \param buffer The buffer to load from.
* \param indices The indices of the location to load.
* \param predicate Optional predicate mask; std::nullopt when none is given.
* \param span The source span of this expression.
*/
TVM_DLL explicit BufferLoad(Buffer buffer, ffi::Array<PrimExpr> indices,
ffi::Optional<PrimExpr> predicate = std::nullopt, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BufferLoad, PrimExpr, BufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};
/*!
* \brief Load value from the result produced by the producer.
*
* \note This node only appears in high-level DSLs that are built on top of the TIR.
*       It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
*       this node before TIR transformations.
*
* \sa ProducerLoad, DataProducerNode
*/
class ProducerLoadNode : public PrimExprNode {
public:
/*! \brief The buffer producer. */
DataProducer producer;
/*! \brief The location arguments. */
ffi::Array<PrimExpr> indices;
/*! \brief Register fields in declaration order: producer, indices. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ProducerLoadNode>()
.def_ro("producer", &ProducerLoadNode::producer)
.def_ro("indices", &ProducerLoadNode::indices);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, PrimExprNode);
};
/*!
* \brief Managed reference to ProducerLoadNode.
* \sa ProducerLoadNode
*/
class ProducerLoad : public PrimExpr {
public:
/*!
* \brief Construct a producer load.
* \param producer The data producer to load from.
* \param indices The location arguments.
* \param span The source span of this expression.
*/
TVM_DLL explicit ProducerLoad(DataProducer producer, ffi::Array<PrimExpr> indices,
Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ProducerLoad, PrimExpr, ProducerLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
};
/*!
* \brief Construct a vector with lanes elements
*        where its i-th element equals base + i * stride.
*  This is useful to construct an index for a contiguous vector load.
*
*  Examples:
*  - ramp(0, 1, 3) = [0, 1, 2]
*  - ramp(1, 2, 4) = [1, 3, 5, 7]
*/
class RampNode : public PrimExprNode {
public:
/*! \brief The base value. */
PrimExpr base;
/*! \brief The stride of each step. */
PrimExpr stride;
/*! \brief Total number of lanes (held as an expression, not a plain integer). */
PrimExpr lanes;
/*! \brief Register fields in declaration order: base, stride, lanes. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<RampNode>()
.def_ro("base", &RampNode::base)
.def_ro("stride", &RampNode::stride)
.def_ro("lanes", &RampNode::lanes);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, PrimExprNode);
};
/*!
* \brief Managed reference to RampNode
* \sa RampNode
*/
class Ramp : public PrimExpr {
public:
/*!
* \brief Construct ramp(base, stride, lanes).
* \param base The value of element 0.
* \param stride The difference between consecutive elements.
* \param lanes The number of elements.
* \param span The source span of this expression.
*/
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Ramp, PrimExpr, RampNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
};
/*! \brief Create a vector where all the elements are value. */
class BroadcastNode : public PrimExprNode {
public:
/*! \brief The base value. */
PrimExpr value;
/*! \brief The number of lanes (held as an expression, not a plain integer). */
PrimExpr lanes;
/*! \brief Register fields in declaration order: value, lanes. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BroadcastNode>()
.def_ro("value", &BroadcastNode::value)
.def_ro("lanes", &BroadcastNode::lanes);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, PrimExprNode);
};
/*!
* \brief Managed reference to BroadcastNode
* \sa BroadcastNode
*/
class Broadcast : public PrimExpr {
public:
/*!
* \brief Construct a vector whose every lane equals value.
* \param value The value to replicate.
* \param lanes The number of lanes.
* \param span The source span of this expression.
*/
TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, PrimExpr, BroadcastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
};
/*!
* \brief Let binding. Bind var to value then evaluate body.
*/
class LetNode : public PrimExprNode {
public:
/*! \brief The variable. */
Var var;
/*! \brief The value to be bound. */
PrimExpr value;
/*! \brief The result expression. */
PrimExpr body;
/*!
* \brief Register fields in declaration order: var, value, body.
*
* `var` carries the SEqHashDef flag, marking it as a definition site for
* structural equality/hashing.
*/
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LetNode>()
// TODO(tqchen): use SEqHashDefNonRecursive after the next pypi tvm-ffi release
.def_ro("var", &LetNode::var, refl::AttachFieldFlag::SEqHashDefRecursive())
.def_ro("value", &LetNode::value)
.def_ro("body", &LetNode::body);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Let", LetNode, PrimExprNode);
};
/*!
* \brief Managed reference to LetNode
* \sa LetNode
*/
class Let : public PrimExpr {
public:
/*!
* \brief Construct a let expression.
* \param var The variable to bind.
* \param value The value bound to var.
* \param body The expression evaluated with the binding in scope.
* \param span The source span of this expression.
*/
TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Let, PrimExpr, LetNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
};
/*!
* \brief Call node.
*/
class CallNode : public PrimExprNode {
public:
/*!
* \brief The operator(function) being invoked
*
*  - It can be tvm::Op which corresponds to the primitive operators(intrinsics).
*  - It can also be another function in the IRModule (GlobalVar).
*/
RelaxExpr op;
/*! \brief The arguments. */
ffi::Array<PrimExpr> args;
/*! \brief The additional attributes. */
Attrs attrs;
/*! \brief Register fields in declaration order: op, args, attrs. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CallNode>()
.def_ro("op", &CallNode::op)
.def_ro("args", &CallNode::args)
.def_ro("attrs", &CallNode::attrs);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode);
};
/*!
* \brief Managed reference to CallNode
* \sa CallNode
*/
class Call : public PrimExpr {
public:
/*!
* \brief Construct a call.
* \param dtype The result data type.
* \param op The operator or function being invoked.
* \param args The arguments.
* \param attrs Additional attributes (default-constructed if omitted).
* \param span The source span of this expression.
*/
TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Attrs attrs = Attrs(),
Span span = Span());
/*! \brief Overload taking a span without attrs. */
TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};
/*!
* \brief Shuffle instruction.
*  vec = concat(vectors)
*  result = (vec[indices[0]], vec[indices[1]] ...)
*/
class ShuffleNode : public PrimExprNode {
public:
/*! \brief the input vectors. */
ffi::Array<PrimExpr> vectors;
/*! \brief The indices of each element. */
ffi::Array<PrimExpr> indices;
/*! \brief Register fields in declaration order: vectors, indices. */
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ShuffleNode>()
.def_ro("vectors", &ShuffleNode::vectors)
.def_ro("indices", &ShuffleNode::indices);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, PrimExprNode);
};
/*!
* \brief Managed reference to ShuffleNode
* \sa ShuffleNode
*/
class Shuffle : public PrimExpr {
public:
/*!
* \brief Construct a shuffle.
* \param vectors The input vectors (conceptually concatenated).
* \param indices The index into the concatenation for each output element.
* \param span The source span of this expression.
*/
TVM_DLL Shuffle(ffi::Array<PrimExpr> vectors, ffi::Array<PrimExpr> indices, Span span = Span());
// Helper builders; they return PrimExpr rather than Shuffle (see implementation
// for the exact form produced).
TVM_DLL static PrimExpr Concat(ffi::Array<PrimExpr> vectors, Span span = Span());
TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Shuffle, PrimExpr, ShuffleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode);
};
// Reduce operator
/*!
* \brief A commutative reducer node to represent a commutative
*  binary operator with identity element
*/
class CommReducerNode : public ffi::Object {
public:
/*! \brief The left argument of reducer */
ffi::Array<Var> lhs;
/*! \brief The right argument of reducer */
ffi::Array<Var> rhs;
/*! \brief The result of reducer */
ffi::Array<PrimExpr> result;
/*!
* \brief The identity element of reducer, which leaves other
*  elements unchanged when combined with it, with respect to
*  the binary operation this reducer uses.
*/
ffi::Array<PrimExpr> identity_element;
/*! \brief Function call operator to combine a and b */
ffi::Array<PrimExpr> operator()(ffi::Array<PrimExpr> a, ffi::Array<PrimExpr> b) const;
/*!
* \brief Span that points to the original source code.
*        Reserved debug information.
*/
mutable Span span;
/*!
* \brief Register reflected fields.
*
* lhs/rhs are marked as definition sites for structural equality/hashing;
* span is excluded from structural equality/hashing.
*/
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CommReducerNode>()
.def_ro("lhs", &CommReducerNode::lhs, refl::AttachFieldFlag::SEqHashDefRecursive())
.def_ro("rhs", &CommReducerNode::rhs, refl::AttachFieldFlag::SEqHashDefRecursive())
.def_ro("result", &CommReducerNode::result)
.def_ro("identity_element", &CommReducerNode::identity_element)
.def_ro("span", &CommReducerNode::span, refl::AttachFieldFlag::SEqHashIgnore());
}
// Compared/hashed structurally as a tree node.
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.CommReducer", CommReducerNode, ffi::Object);
};
/*!
* \brief Managed reference to CommReducerNode
* \sa CommReducerNode
*/
class CommReducer : public ffi::ObjectRef {
public:
/*!
* \brief Construct a commutative reducer.
* \param lhs The left argument variables.
* \param rhs The right argument variables.
* \param result The combined result expressions.
* \param identity_element The identity elements.
* \param span The source span.
*/
TVM_DLL CommReducer(ffi::Array<Var> lhs, ffi::Array<Var> rhs, ffi::Array<PrimExpr> result,
ffi::Array<PrimExpr> identity_element, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CommReducer, ffi::ObjectRef, CommReducerNode);
};
/*! \brief Reduction operator */
class ReduceNode : public PrimExprNode {
public:
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
ffi::Array<PrimExpr> source;
/*! \brief The init operand */
ffi::Array<PrimExpr> init;
/*! \brief The reduction axis */
ffi::Array<IterVar> axis;
/*!
* \brief Predicate on the reduction
*  Only add the body to reduction if condition is true.
*/
PrimExpr condition;
/*! \brief the index of this reduce node */
int value_index;
/*!
* \brief Register fields: combiner, source, init, axis, condition, value_index.
* \note Registration order is kept as-is; reflection consumers may depend on it.
*/
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceNode>()
.def_ro("combiner", &ReduceNode::combiner)
.def_ro("source", &ReduceNode::source)
.def_ro("init", &ReduceNode::init)
.def_ro("axis", &ReduceNode::axis)
.def_ro("condition", &ReduceNode::condition)
.def_ro("value_index", &ReduceNode::value_index);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, PrimExprNode);
};
/*!
* \brief Managed reference to ReduceNode
* \sa ReduceNode
*/
class Reduce : public PrimExpr {
public:
/*!
* \brief Construct a reduction.
* \param combiner The commutative combiner.
* \param src The source operands.
* \param rdom The reduction domain (iteration variables).
* \param condition The reduction predicate.
* \param value_index The index of this reduce node.
* \param init The init operands.
* \param span The source span of this expression.
* \note Parameter order differs from the ReduceNode field order.
*/
TVM_DLL Reduce(CommReducer combiner, ffi::Array<PrimExpr> src, ffi::Array<IterVar> rdom,
PrimExpr condition, int value_index, ffi::Array<PrimExpr> init,
Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Reduce, PrimExpr, ReduceNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
};
/*!
 * \brief Template function to convert Map to unordered_map.
 *  Sometimes useful for API gluing when internal code uses unordered_map.
 * \param dmap The container map.
 * \return The corresponding unordered_map.
 * \tparam K the key of the Map.
 * \tparam V the value of the Map.
 */
template <typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const ffi::Map<K, V>& dmap) {
  std::unordered_map<K, V> ret;
  // Size is known up front: allocate buckets once instead of rehashing as we insert.
  ret.reserve(dmap.size());
  // Bind by const reference to avoid copying each key/value pair.
  for (const auto& kv : dmap) {
    // Map keys are unique, so emplace never hits an existing entry. Unlike
    // operator[], it does not require V to be default-constructible and it
    // does not default-construct then assign.
    ret.emplace(kv.first, kv.second);
  }
  return ret;
}
} // namespace tirx
namespace ffi {
/*!
 * \brief Opt StringImm out of the default object-reference type traits so that
 *  the explicit TypeTraits specialization below is used instead.
 */
template <>
inline constexpr bool use_default_type_traits_v<tvm::tirx::StringImm> = false;
/*!
 * \brief Type traits for StringImm with an ffi::String fallback.
 *
 * Built on ObjectRefWithFallbackTraitsBase<StringImm, ffi::String>. When the
 * fallback path applies, ConvertFallbackValue wraps the ffi::String into a new
 * StringImm.
 */
template <>
struct TypeTraits<tvm::tirx::StringImm>
    : public ObjectRefWithFallbackTraitsBase<tvm::tirx::StringImm, ffi::String> {
  TVM_FFI_INLINE static tvm::tirx::StringImm ConvertFallbackValue(ffi::String value) {
    // `value` is already an owned copy; move it into the constructor (which
    // also takes ffi::String by value) to avoid an extra refcount inc/dec.
    return tvm::tirx::StringImm(std::move(value));
  }
};
}  // namespace ffi
} // namespace tvm
namespace std {
// Hash IterVar via tvm::ffi::ObjectPtrHash (hashing the underlying object
// pointer), so IterVar can be used as a key in std::unordered_map/unordered_set
// (e.g. through tirx::as_unordered_map).
template <>
struct hash<::tvm::tirx::IterVar> : public ::tvm::ffi::ObjectPtrHash {};
}  // namespace std
#endif // TVM_TIR_EXPR_H_