| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file src/tir/ir/py_functor.cc |
| * \brief The python interface of ExprVisitor/ExprMutator, StmtVisitor/StmtMutator, |
| * StmtExprVisitor/StmtExprMutator. |
| */ |
| |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/tir/expr_functor.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| namespace tvm { |
| namespace tir { |
| |
| // ================================================ |
| // Helper Macros |
| // ================================================ |
// Generates a `VisitExpr_(const OP*)` override for the visitor node. When the
// Python callback PY_FUNC is set it receives the node; otherwise the stock
// StmtExprVisitor implementation is used.
#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \
  void VisitExpr_(const OP* op) override {    \
    if (PY_FUNC != nullptr) {                 \
      PY_FUNC(op);                            \
      return;                                 \
    }                                         \
    StmtExprVisitor::VisitExpr_(op);          \
  }

// Registers, in the local `vtable`, an entry for OP that calls the stock
// StmtExprVisitor::VisitExpr_ non-virtually (bypassing the Python overrides).
#define IR_EXPR_VISITOR_DEFAULT_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* visitor) { \
    visitor->StmtExprVisitor::VisitExpr_(static_cast<const OP*>(node.get()));  \
  });

// Generates a `VisitStmt_(const OP*)` override for the visitor node. When the
// Python callback PY_FUNC is set it receives the node; otherwise the stock
// StmtExprVisitor implementation is used.
#define PY_STMT_VISITOR_DISPATCH(OP, PY_FUNC) \
  void VisitStmt_(const OP* op) override {    \
    if (PY_FUNC != nullptr) {                 \
      PY_FUNC(op);                            \
      return;                                 \
    }                                         \
    StmtExprVisitor::VisitStmt_(op);          \
  }

// Registers, in the local `vtable`, an entry for OP that calls the stock
// StmtExprVisitor::VisitStmt_ non-virtually (bypassing the Python overrides).
#define PY_STMT_VISITOR_DEFAULT_DISPATCH(OP)                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* visitor) { \
    visitor->StmtExprVisitor::VisitStmt_(static_cast<const OP*>(node.get()));  \
  });

// Generates a `VisitExpr_(const OP*)` override for the mutator node. When the
// Python callback PY_FUNC is set, its result is cast to PrimExpr and returned;
// otherwise the stock StmtExprMutator result is returned.
#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC)  \
  PrimExpr VisitExpr_(const OP* op) override { \
    if (PY_FUNC != nullptr) {                  \
      return PY_FUNC(op).cast<PrimExpr>();     \
    }                                          \
    return StmtExprMutator::VisitExpr_(op);    \
  }

// Registers, in the local `vtable`, an entry for OP that returns the result of
// the stock StmtExprMutator::VisitExpr_, called non-virtually.
#define PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OP)                                           \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* mutator) {        \
    return mutator->StmtExprMutator::VisitExpr_(static_cast<const OP*>(node.get())); \
  });

// Generates a `VisitStmt_(const OP*)` override for the mutator node. When the
// Python callback PY_FUNC is set, its result is cast to Stmt and returned;
// otherwise the stock StmtExprMutator result is returned.
#define PY_STMT_MUTATOR_DISPATCH(OP, PY_FUNC) \
  Stmt VisitStmt_(const OP* op) override {    \
    if (PY_FUNC != nullptr) {                 \
      return PY_FUNC(op).cast<Stmt>();        \
    }                                         \
    return StmtExprMutator::VisitStmt_(op);   \
  }

// Registers, in the local `vtable`, an entry for OP that returns the result of
// the stock StmtExprMutator::VisitStmt_, called non-virtually.
#define PY_STMT_MUTATOR_DEFAULT_DISPATCH(OP)                                           \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* mutator) {        \
    return mutator->StmtExprMutator::VisitStmt_(static_cast<const OP*>(node.get())); \
  });
| |
| /*! \brief The python interface of StmtExprVisitor. */ |
| class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { |
| private: |
| using TSelf = PyStmtExprVisitorNode; |
| using FExprType = tvm::NodeFunctor<void(const ObjectRef& n, TSelf* self)>; |
| using FStmtType = tvm::NodeFunctor<void(const ObjectRef& n, TSelf* self)>; |
| |
| public: |
| // Expression functions |
| /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ |
| ffi::Function f_visit_expr{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ |
| ffi::Function f_visit_var{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */ |
| ffi::Function f_visit_size_var{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */ |
| ffi::Function f_visit_buffer_load{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */ |
| ffi::Function f_visit_producer_load{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */ |
| ffi::Function f_visit_let{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ |
| ffi::Function f_visit_call{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */ |
| ffi::Function f_visit_add{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */ |
| ffi::Function f_visit_sub{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */ |
| ffi::Function f_visit_mul{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */ |
| ffi::Function f_visit_div{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */ |
| ffi::Function f_visit_mod{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */ |
| ffi::Function f_visit_floor_div{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */ |
| ffi::Function f_visit_floor_mod{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */ |
| ffi::Function f_visit_min{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */ |
| ffi::Function f_visit_max{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */ |
| ffi::Function f_visit_eq{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */ |
| ffi::Function f_visit_ne{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */ |
| ffi::Function f_visit_lt{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */ |
| ffi::Function f_visit_le{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */ |
| ffi::Function f_visit_gt{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */ |
| ffi::Function f_visit_ge{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */ |
| ffi::Function f_visit_and{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */ |
| ffi::Function f_visit_or{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */ |
| ffi::Function f_visit_reduce{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */ |
| ffi::Function f_visit_cast{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */ |
| ffi::Function f_visit_not{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */ |
| ffi::Function f_visit_select{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */ |
| ffi::Function f_visit_ramp{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */ |
| ffi::Function f_visit_broadcast{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */ |
| ffi::Function f_visit_shuffle{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */ |
| ffi::Function f_visit_int_imm{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */ |
| ffi::Function f_visit_float_imm{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ |
| ffi::Function f_visit_string_imm{nullptr}; |
| |
| // Statement functions |
| /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ |
| ffi::Function f_visit_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ |
| ffi::Function f_visit_attr_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ |
| ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces) |
| /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */ |
| ffi::Function f_visit_let_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ |
| ffi::Function f_visit_for{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */ |
| ffi::Function f_visit_while{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */ |
| ffi::Function f_visit_allocate{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */ |
| ffi::Function f_visit_allocate_const{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */ |
| ffi::Function f_visit_decl_buffer{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */ |
| ffi::Function f_visit_buffer_store{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */ |
| ffi::Function f_visit_buffer_realize{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */ |
| ffi::Function f_visit_assert_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */ |
| ffi::Function f_visit_seq_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */ |
| ffi::Function f_visit_evaluate{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */ |
| ffi::Function f_visit_block{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */ |
| ffi::Function f_visit_block_realize{nullptr}; |
| |
| using StmtExprVisitor::VisitExpr; |
| using StmtExprVisitor::VisitStmt; |
| |
| void DefaultVisitExpr(const PrimExpr& expr) { |
| static FExprType vtable = InitExprVTable(); |
| vtable(expr, this); |
| } |
| |
| void DefaultVisitStmt(const Stmt& stmt) { |
| static FStmtType vtable = InitStmtVTable(); |
| vtable(stmt, this); |
| } |
| |
| static void RegisterReflection() { |
| // No fields to register as they are not visited |
| } |
| |
| static constexpr const bool _type_mutable = true; |
| TVM_FFI_DECLARE_OBJECT_INFO("tir.PyStmtExprVisitor", PyStmtExprVisitorNode, Object); |
| |
| private: |
| // Statement functions |
| PY_STMT_VISITOR_DISPATCH(LetStmtNode, f_visit_let_stmt); |
| PY_STMT_VISITOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); |
| PY_STMT_VISITOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); |
| PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for); |
| PY_STMT_VISITOR_DISPATCH(WhileNode, f_visit_while); |
| PY_STMT_VISITOR_DISPATCH(AllocateNode, f_visit_allocate); |
| PY_STMT_VISITOR_DISPATCH(AllocateConstNode, f_visit_allocate_const); |
| PY_STMT_VISITOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer); |
| PY_STMT_VISITOR_DISPATCH(BufferStoreNode, f_visit_buffer_store); |
| PY_STMT_VISITOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize); |
| PY_STMT_VISITOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt); |
| PY_STMT_VISITOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt); |
| PY_STMT_VISITOR_DISPATCH(EvaluateNode, f_visit_evaluate); |
| PY_STMT_VISITOR_DISPATCH(BlockNode, f_visit_block); |
| PY_STMT_VISITOR_DISPATCH(BlockRealizeNode, f_visit_block_realize); |
| // Expression functions |
| PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var); |
| PY_EXPR_VISITOR_DISPATCH(SizeVarNode, f_visit_size_var); |
| PY_EXPR_VISITOR_DISPATCH(BufferLoadNode, f_visit_buffer_load); |
| PY_EXPR_VISITOR_DISPATCH(ProducerLoadNode, f_visit_producer_load); |
| PY_EXPR_VISITOR_DISPATCH(LetNode, f_visit_let); |
| PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call); |
| PY_EXPR_VISITOR_DISPATCH(AddNode, f_visit_add); |
| PY_EXPR_VISITOR_DISPATCH(SubNode, f_visit_sub); |
| PY_EXPR_VISITOR_DISPATCH(MulNode, f_visit_mul); |
| PY_EXPR_VISITOR_DISPATCH(DivNode, f_visit_div); |
| PY_EXPR_VISITOR_DISPATCH(ModNode, f_visit_mod); |
| PY_EXPR_VISITOR_DISPATCH(FloorDivNode, f_visit_floor_div); |
| PY_EXPR_VISITOR_DISPATCH(FloorModNode, f_visit_floor_mod); |
| PY_EXPR_VISITOR_DISPATCH(MinNode, f_visit_min); |
| PY_EXPR_VISITOR_DISPATCH(MaxNode, f_visit_max); |
| PY_EXPR_VISITOR_DISPATCH(EQNode, f_visit_eq); |
| PY_EXPR_VISITOR_DISPATCH(NENode, f_visit_ne); |
| PY_EXPR_VISITOR_DISPATCH(LTNode, f_visit_lt); |
| PY_EXPR_VISITOR_DISPATCH(LENode, f_visit_le); |
| PY_EXPR_VISITOR_DISPATCH(GTNode, f_visit_gt); |
| PY_EXPR_VISITOR_DISPATCH(GENode, f_visit_ge); |
| PY_EXPR_VISITOR_DISPATCH(AndNode, f_visit_and); |
| PY_EXPR_VISITOR_DISPATCH(OrNode, f_visit_or); |
| PY_EXPR_VISITOR_DISPATCH(ReduceNode, f_visit_reduce); |
| PY_EXPR_VISITOR_DISPATCH(CastNode, f_visit_cast); |
| PY_EXPR_VISITOR_DISPATCH(NotNode, f_visit_not); |
| PY_EXPR_VISITOR_DISPATCH(SelectNode, f_visit_select); |
| PY_EXPR_VISITOR_DISPATCH(RampNode, f_visit_ramp); |
| PY_EXPR_VISITOR_DISPATCH(BroadcastNode, f_visit_broadcast); |
| PY_EXPR_VISITOR_DISPATCH(ShuffleNode, f_visit_shuffle); |
| PY_EXPR_VISITOR_DISPATCH(IntImmNode, f_visit_int_imm); |
| PY_EXPR_VISITOR_DISPATCH(FloatImmNode, f_visit_float_imm); |
| PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm); |
| |
| private: |
| static FExprType InitExprVTable() { |
| FExprType vtable; |
| // Set dispatch |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(VarNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(SizeVarNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(BufferLoadNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(ProducerLoadNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(LetNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(CallNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(AddNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(SubNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(MulNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(DivNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(ModNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorDivNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloorModNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(MinNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(MaxNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(EQNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(NENode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(LTNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(LENode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(GTNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(GENode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(AndNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(OrNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(ReduceNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(CastNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(NotNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(SelectNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(RampNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(ShuffleNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(BroadcastNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(IntImmNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(FloatImmNode); |
| IR_EXPR_VISITOR_DEFAULT_DISPATCH(StringImmNode); |
| vtable.Finalize(); |
| return vtable; |
| } |
| |
| static FStmtType InitStmtVTable() { |
| FStmtType vtable; |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(LetStmtNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(AttrStmtNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(IfThenElseNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(WhileNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateConstNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(DeclBufferNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferStoreNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferRealizeNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(AssertStmtNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(SeqStmtNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(EvaluateNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockNode); |
| PY_STMT_VISITOR_DEFAULT_DISPATCH(BlockRealizeNode); |
| vtable.Finalize(); |
| return vtable; |
| } |
| }; |
| |
| /*! |
| * \brief Managed reference to PyStmtExprVisitorNode. |
| * \sa PyStmtExprVisitorNode |
| */ |
| class PyStmtExprVisitor : public ObjectRef { |
| public: |
| explicit PyStmtExprVisitor(ObjectPtr<PyStmtExprVisitorNode> data) : ObjectRef(data) { |
| TVM_FFI_ICHECK(data != nullptr); |
| } |
| TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function f_visit_stmt, // |
| ffi::Function f_visit_expr, // |
| ffi::Function f_visit_let_stmt, // |
| ffi::Function f_visit_attr_stmt, // |
| ffi::Function f_visit_if_then_else, // |
| ffi::Function f_visit_for, // |
| ffi::Function f_visit_while, // |
| ffi::Function f_visit_allocate, // |
| ffi::Function f_visit_allocate_const, // |
| ffi::Function f_visit_decl_buffer, // |
| ffi::Function f_visit_buffer_store, // |
| ffi::Function f_visit_buffer_realize, // |
| ffi::Function f_visit_assert_stmt, // |
| ffi::Function f_visit_seq_stmt, // |
| ffi::Function f_visit_evaluate, // |
| ffi::Function f_visit_block, // |
| ffi::Function f_visit_block_realize, // |
| ffi::Function f_visit_var, // |
| ffi::Function f_visit_size_var, // |
| ffi::Function f_visit_buffer_load, // |
| ffi::Function f_visit_producer_load, // |
| ffi::Function f_visit_let, // |
| ffi::Function f_visit_call, // |
| ffi::Function f_visit_add, // |
| ffi::Function f_visit_sub, // |
| ffi::Function f_visit_mul, // |
| ffi::Function f_visit_div, // |
| ffi::Function f_visit_mod, // |
| ffi::Function f_visit_floor_div, // |
| ffi::Function f_visit_floor_mod, // |
| ffi::Function f_visit_min, // |
| ffi::Function f_visit_max, // |
| ffi::Function f_visit_eq, // |
| ffi::Function f_visit_ne, // |
| ffi::Function f_visit_lt, // |
| ffi::Function f_visit_le, // |
| ffi::Function f_visit_gt, // |
| ffi::Function f_visit_ge, // |
| ffi::Function f_visit_and, // |
| ffi::Function f_visit_or, // |
| ffi::Function f_visit_reduce, // |
| ffi::Function f_visit_cast, // |
| ffi::Function f_visit_not, // |
| ffi::Function f_visit_select, // |
| ffi::Function f_visit_ramp, // |
| ffi::Function f_visit_broadcast, // |
| ffi::Function f_visit_shuffle, // |
| ffi::Function f_visit_int_imm, // |
| ffi::Function f_visit_float_imm, // |
| ffi::Function f_visit_string_imm) { |
| ObjectPtr<PyStmtExprVisitorNode> n = ffi::make_object<PyStmtExprVisitorNode>(); |
| n->f_visit_stmt = std::move(f_visit_stmt); |
| n->f_visit_expr = std::move(f_visit_expr); |
| // Set statement functions |
| n->f_visit_let_stmt = std::move(f_visit_let_stmt); |
| n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); |
| n->f_visit_if_then_else = std::move(f_visit_if_then_else); |
| n->f_visit_for = std::move(f_visit_for); |
| n->f_visit_while = std::move(f_visit_while); |
| n->f_visit_allocate = std::move(f_visit_allocate); |
| n->f_visit_allocate_const = std::move(f_visit_allocate_const); |
| n->f_visit_decl_buffer = std::move(f_visit_decl_buffer); |
| n->f_visit_buffer_store = std::move(f_visit_buffer_store); |
| n->f_visit_buffer_realize = std::move(f_visit_buffer_realize); |
| n->f_visit_assert_stmt = std::move(f_visit_assert_stmt); |
| n->f_visit_seq_stmt = std::move(f_visit_seq_stmt); |
| n->f_visit_evaluate = std::move(f_visit_evaluate); |
| n->f_visit_block = std::move(f_visit_block); |
| n->f_visit_block_realize = std::move(f_visit_block_realize); |
| // Set expression functions |
| n->f_visit_var = std::move(f_visit_var); |
| n->f_visit_size_var = std::move(f_visit_size_var); |
| n->f_visit_buffer_load = std::move(f_visit_buffer_load); |
| n->f_visit_producer_load = std::move(f_visit_producer_load); |
| n->f_visit_let = std::move(f_visit_let); |
| n->f_visit_call = std::move(f_visit_call); |
| n->f_visit_add = std::move(f_visit_add); |
| n->f_visit_sub = std::move(f_visit_sub); |
| n->f_visit_mul = std::move(f_visit_mul); |
| n->f_visit_div = std::move(f_visit_div); |
| n->f_visit_mod = std::move(f_visit_mod); |
| n->f_visit_floor_div = std::move(f_visit_floor_div); |
| n->f_visit_floor_mod = std::move(f_visit_floor_mod); |
| n->f_visit_min = std::move(f_visit_min); |
| n->f_visit_max = std::move(f_visit_max); |
| n->f_visit_eq = std::move(f_visit_eq); |
| n->f_visit_ne = std::move(f_visit_ne); |
| n->f_visit_lt = std::move(f_visit_lt); |
| n->f_visit_le = std::move(f_visit_le); |
| n->f_visit_gt = std::move(f_visit_gt); |
| n->f_visit_ge = std::move(f_visit_ge); |
| n->f_visit_and = std::move(f_visit_and); |
| n->f_visit_or = std::move(f_visit_or); |
| n->f_visit_reduce = std::move(f_visit_reduce); |
| n->f_visit_cast = std::move(f_visit_cast); |
| n->f_visit_not = std::move(f_visit_not); |
| n->f_visit_select = std::move(f_visit_select); |
| n->f_visit_ramp = std::move(f_visit_ramp); |
| n->f_visit_broadcast = std::move(f_visit_broadcast); |
| n->f_visit_shuffle = std::move(f_visit_shuffle); |
| n->f_visit_int_imm = std::move(f_visit_int_imm); |
| n->f_visit_float_imm = std::move(f_visit_float_imm); |
| n->f_visit_string_imm = std::move(f_visit_string_imm); |
| return PyStmtExprVisitor(n); |
| } |
| |
| TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprVisitor, ObjectRef, |
| PyStmtExprVisitorNode); |
| }; |
| |
| /*! \brief The python interface of StmtExprMutator. */ |
| class PyStmtExprMutatorNode : public Object, public StmtExprMutator { |
| private: |
| using TSelf = PyStmtExprMutatorNode; |
| using FExprType = tvm::NodeFunctor<PrimExpr(const ObjectRef& n, TSelf* self)>; |
| using FStmtType = tvm::NodeFunctor<Stmt(const ObjectRef& n, TSelf* self)>; |
| |
| public: |
| // Expression functions |
| /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ |
| ffi::Function f_visit_expr{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ |
| ffi::Function f_visit_var{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const SizeVarNode* op)` function. */ |
| ffi::Function f_visit_size_var{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const BufferLoadNode* op)` function. */ |
| ffi::Function f_visit_buffer_load{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ProducerLoadNode* op)` function. */ |
| ffi::Function f_visit_producer_load{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const LetNode* op)` function. */ |
| ffi::Function f_visit_let{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ |
| ffi::Function f_visit_call{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const AddNode* op)` function. */ |
| ffi::Function f_visit_add{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const SubNode* op)` function. */ |
| ffi::Function f_visit_sub{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const MulNode* op)` function. */ |
| ffi::Function f_visit_mul{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const DivNode* op)` function. */ |
| ffi::Function f_visit_div{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ModNode* op)` function. */ |
| ffi::Function f_visit_mod{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const FloorDivNode* op)` function. */ |
| ffi::Function f_visit_floor_div{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const FloorModNode* op)` function. */ |
| ffi::Function f_visit_floor_mod{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const MinNode* op)` function. */ |
| ffi::Function f_visit_min{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const MaxNode* op)` function. */ |
| ffi::Function f_visit_max{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const EQNode* op)` function. */ |
| ffi::Function f_visit_eq{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const NENode* op)` function. */ |
| ffi::Function f_visit_ne{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const LTNode* op)` function. */ |
| ffi::Function f_visit_lt{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const LENode* op)` function. */ |
| ffi::Function f_visit_le{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const GTNode* op)` function. */ |
| ffi::Function f_visit_gt{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const GENode* op)` function. */ |
| ffi::Function f_visit_ge{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const AndNode* op)` function. */ |
| ffi::Function f_visit_and{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const OrNode* op)` function. */ |
| ffi::Function f_visit_or{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ReduceNode* op)` function. */ |
| ffi::Function f_visit_reduce{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const CastNode* op)` function. */ |
| ffi::Function f_visit_cast{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const NotNode* op)` function. */ |
| ffi::Function f_visit_not{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const SelectNode* op)` function. */ |
| ffi::Function f_visit_select{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const RampNode* op)` function. */ |
| ffi::Function f_visit_ramp{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const BroadcastNode* op)` function. */ |
| ffi::Function f_visit_broadcast{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const ShuffleNode* op)` function. */ |
| ffi::Function f_visit_shuffle{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const IntImmNode* op)` function. */ |
| ffi::Function f_visit_int_imm{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const FloatImmNode* op)` function. */ |
| ffi::Function f_visit_float_imm{nullptr}; |
| /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ |
| ffi::Function f_visit_string_imm{nullptr}; |
| |
| // Statement functions |
| /*! \brief The packed function to the `VisitStmt(const Stmt& stmt)` function. */ |
| ffi::Function f_visit_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const LetStmtNode* op)` function. */ |
| ffi::Function f_visit_let_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AttrStmtNode* op)` function. */ |
| ffi::Function f_visit_attr_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const IfThenElseNode* op)` function. */ |
| ffi::Function f_visit_if_then_else{nullptr}; // NOLINT(readability/braces) |
| /*! \brief The packed function to the `VisitStmt_(const ForNode* op)` function. */ |
| ffi::Function f_visit_for{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const WhileNode* op)` function. */ |
| ffi::Function f_visit_while{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */ |
| ffi::Function f_visit_allocate{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */ |
| ffi::Function f_visit_allocate_const{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */ |
| ffi::Function f_visit_decl_buffer{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. */ |
| ffi::Function f_visit_buffer_store{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BufferRealizeNode* op)` function. */ |
| ffi::Function f_visit_buffer_realize{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const AssertStmtNode* op)` function. */ |
| ffi::Function f_visit_assert_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const SeqStmtNode* op)` function. */ |
| ffi::Function f_visit_seq_stmt{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const EvaluateNode* op)` function. */ |
| ffi::Function f_visit_evaluate{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BlockNode* op)` function. */ |
| ffi::Function f_visit_block{nullptr}; |
| /*! \brief The packed function to the `VisitStmt_(const BlockRealizeNode* op)` function. */ |
| ffi::Function f_visit_block_realize{nullptr}; |
| |
| using StmtExprMutator::VisitExpr; |
| using StmtExprMutator::VisitStmt; |
| |
| void DefaultVisitExpr(const PrimExpr& expr) { |
| static FExprType vtable = InitExprVTable(); |
| vtable(expr, this); |
| } |
| |
| void DefaultVisitStmt(const Stmt& stmt) { |
| static FStmtType vtable = InitStmtVTable(); |
| vtable(stmt, this); |
| } |
| |
| static void RegisterReflection() { |
| // No fields to register as they are not visited |
| } |
| |
| static constexpr const bool _type_mutable = true; |
| TVM_FFI_DECLARE_OBJECT_INFO("tir.PyStmtExprMutator", PyStmtExprMutatorNode, Object); |
| |
| private: |
| // Statement functions |
| PY_STMT_MUTATOR_DISPATCH(LetStmtNode, f_visit_let_stmt); |
| PY_STMT_MUTATOR_DISPATCH(AttrStmtNode, f_visit_attr_stmt); |
| PY_STMT_MUTATOR_DISPATCH(IfThenElseNode, f_visit_if_then_else); |
| PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for); |
| PY_STMT_MUTATOR_DISPATCH(WhileNode, f_visit_while); |
| PY_STMT_MUTATOR_DISPATCH(AllocateNode, f_visit_allocate); |
| PY_STMT_MUTATOR_DISPATCH(AllocateConstNode, f_visit_allocate_const); |
| PY_STMT_MUTATOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer); |
| PY_STMT_MUTATOR_DISPATCH(BufferStoreNode, f_visit_buffer_store); |
| PY_STMT_MUTATOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize); |
| PY_STMT_MUTATOR_DISPATCH(AssertStmtNode, f_visit_assert_stmt); |
| PY_STMT_MUTATOR_DISPATCH(SeqStmtNode, f_visit_seq_stmt); |
| PY_STMT_MUTATOR_DISPATCH(EvaluateNode, f_visit_evaluate); |
| PY_STMT_MUTATOR_DISPATCH(BlockNode, f_visit_block); |
| PY_STMT_MUTATOR_DISPATCH(BlockRealizeNode, f_visit_block_realize); |
| // Expression functions |
| PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var); |
| PY_EXPR_MUTATOR_DISPATCH(SizeVarNode, f_visit_size_var); |
| PY_EXPR_MUTATOR_DISPATCH(BufferLoadNode, f_visit_buffer_load); |
| PY_EXPR_MUTATOR_DISPATCH(ProducerLoadNode, f_visit_producer_load); |
| PY_EXPR_MUTATOR_DISPATCH(LetNode, f_visit_let); |
| PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call); |
| PY_EXPR_MUTATOR_DISPATCH(AddNode, f_visit_add); |
| PY_EXPR_MUTATOR_DISPATCH(SubNode, f_visit_sub); |
| PY_EXPR_MUTATOR_DISPATCH(MulNode, f_visit_mul); |
| PY_EXPR_MUTATOR_DISPATCH(DivNode, f_visit_div); |
| PY_EXPR_MUTATOR_DISPATCH(ModNode, f_visit_mod); |
| PY_EXPR_MUTATOR_DISPATCH(FloorDivNode, f_visit_floor_div); |
| PY_EXPR_MUTATOR_DISPATCH(FloorModNode, f_visit_floor_mod); |
| PY_EXPR_MUTATOR_DISPATCH(MinNode, f_visit_min); |
| PY_EXPR_MUTATOR_DISPATCH(MaxNode, f_visit_max); |
| PY_EXPR_MUTATOR_DISPATCH(EQNode, f_visit_eq); |
| PY_EXPR_MUTATOR_DISPATCH(NENode, f_visit_ne); |
| PY_EXPR_MUTATOR_DISPATCH(LTNode, f_visit_lt); |
| PY_EXPR_MUTATOR_DISPATCH(LENode, f_visit_le); |
| PY_EXPR_MUTATOR_DISPATCH(GTNode, f_visit_gt); |
| PY_EXPR_MUTATOR_DISPATCH(GENode, f_visit_ge); |
| PY_EXPR_MUTATOR_DISPATCH(AndNode, f_visit_and); |
| PY_EXPR_MUTATOR_DISPATCH(OrNode, f_visit_or); |
| PY_EXPR_MUTATOR_DISPATCH(ReduceNode, f_visit_reduce); |
| PY_EXPR_MUTATOR_DISPATCH(CastNode, f_visit_cast); |
| PY_EXPR_MUTATOR_DISPATCH(NotNode, f_visit_not); |
| PY_EXPR_MUTATOR_DISPATCH(SelectNode, f_visit_select); |
| PY_EXPR_MUTATOR_DISPATCH(RampNode, f_visit_ramp); |
| PY_EXPR_MUTATOR_DISPATCH(BroadcastNode, f_visit_broadcast); |
| PY_EXPR_MUTATOR_DISPATCH(ShuffleNode, f_visit_shuffle); |
| PY_EXPR_MUTATOR_DISPATCH(IntImmNode, f_visit_int_imm); |
| PY_EXPR_MUTATOR_DISPATCH(FloatImmNode, f_visit_float_imm); |
| PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm); |
| |
| private: |
| static FExprType InitExprVTable() { |
| FExprType vtable; |
| // Set dispatch |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(VarNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SizeVarNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BufferLoadNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ProducerLoadNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LetNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CallNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AddNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SubNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MulNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(DivNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ModNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorDivNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloorModNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MinNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(MaxNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(EQNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NENode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LTNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(LENode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GTNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(GENode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(AndNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(OrNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ReduceNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(CastNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(NotNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(SelectNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(RampNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(ShuffleNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(BroadcastNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(IntImmNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(FloatImmNode); |
| PY_EXPR_MUTATOR_DEFAULT_DISPATCH(StringImmNode); |
| vtable.Finalize(); |
| return vtable; |
| } |
| |
| static FStmtType InitStmtVTable() { |
| FStmtType vtable; |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(LetStmtNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(AttrStmtNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(IfThenElseNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(WhileNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateConstNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(DeclBufferNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferStoreNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferRealizeNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(AssertStmtNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(SeqStmtNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(EvaluateNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockNode); |
| PY_STMT_MUTATOR_DEFAULT_DISPATCH(BlockRealizeNode); |
| vtable.Finalize(); |
| return vtable; |
| } |
| }; |
| |
| /*! \brief Managed reference to PyStmtExprMutatorNode. */ |
| class PyStmtExprMutator : public ObjectRef { |
| public: |
| explicit PyStmtExprMutator(ObjectPtr<PyStmtExprMutatorNode> data) : ObjectRef(data) { |
| TVM_FFI_ICHECK(data != nullptr); |
| } |
| /*! |
| * \brief Create a PyStmtExprMutator with customized methods on the python-side. |
| * \return The PyStmtExprMutator created. |
| */ |
| TVM_DLL static PyStmtExprMutator MakePyStmtExprMutator(ffi::Function f_visit_stmt, // |
| ffi::Function f_visit_expr, // |
| ffi::Function f_visit_let_stmt, // |
| ffi::Function f_visit_attr_stmt, // |
| ffi::Function f_visit_if_then_else, // |
| ffi::Function f_visit_for, // |
| ffi::Function f_visit_while, // |
| ffi::Function f_visit_allocate, // |
| ffi::Function f_visit_allocate_const, // |
| ffi::Function f_visit_decl_buffer, // |
| ffi::Function f_visit_buffer_store, // |
| ffi::Function f_visit_buffer_realize, // |
| ffi::Function f_visit_assert_stmt, // |
| ffi::Function f_visit_seq_stmt, // |
| ffi::Function f_visit_evaluate, // |
| ffi::Function f_visit_block, // |
| ffi::Function f_visit_block_realize, // |
| ffi::Function f_visit_var, // |
| ffi::Function f_visit_size_var, // |
| ffi::Function f_visit_buffer_load, // |
| ffi::Function f_visit_producer_load, // |
| ffi::Function f_visit_let, // |
| ffi::Function f_visit_call, // |
| ffi::Function f_visit_add, // |
| ffi::Function f_visit_sub, // |
| ffi::Function f_visit_mul, // |
| ffi::Function f_visit_div, // |
| ffi::Function f_visit_mod, // |
| ffi::Function f_visit_floor_div, // |
| ffi::Function f_visit_floor_mod, // |
| ffi::Function f_visit_min, // |
| ffi::Function f_visit_max, // |
| ffi::Function f_visit_eq, // |
| ffi::Function f_visit_ne, // |
| ffi::Function f_visit_lt, // |
| ffi::Function f_visit_le, // |
| ffi::Function f_visit_gt, // |
| ffi::Function f_visit_ge, // |
| ffi::Function f_visit_and, // |
| ffi::Function f_visit_or, // |
| ffi::Function f_visit_reduce, // |
| ffi::Function f_visit_cast, // |
| ffi::Function f_visit_not, // |
| ffi::Function f_visit_select, // |
| ffi::Function f_visit_ramp, // |
| ffi::Function f_visit_broadcast, // |
| ffi::Function f_visit_shuffle, // |
| ffi::Function f_visit_int_imm, // |
| ffi::Function f_visit_float_imm, // |
| ffi::Function f_visit_string_imm) { |
| ObjectPtr<PyStmtExprMutatorNode> n = ffi::make_object<PyStmtExprMutatorNode>(); |
| n->f_visit_stmt = std::move(f_visit_stmt); |
| n->f_visit_expr = std::move(f_visit_expr); |
| // Statement functions |
| n->f_visit_let_stmt = std::move(f_visit_let_stmt); |
| n->f_visit_attr_stmt = std::move(f_visit_attr_stmt); |
| n->f_visit_if_then_else = std::move(f_visit_if_then_else); |
| n->f_visit_for = std::move(f_visit_for); |
| n->f_visit_while = std::move(f_visit_while); |
| n->f_visit_allocate = std::move(f_visit_allocate); |
| n->f_visit_allocate_const = std::move(f_visit_allocate_const); |
| n->f_visit_decl_buffer = std::move(f_visit_decl_buffer); |
| n->f_visit_buffer_store = std::move(f_visit_buffer_store); |
| n->f_visit_buffer_realize = std::move(f_visit_buffer_realize); |
| n->f_visit_assert_stmt = std::move(f_visit_assert_stmt); |
| n->f_visit_seq_stmt = std::move(f_visit_seq_stmt); |
| n->f_visit_evaluate = std::move(f_visit_evaluate); |
| n->f_visit_block = std::move(f_visit_block); |
| n->f_visit_block_realize = std::move(f_visit_block_realize); |
| // Expression functions |
| n->f_visit_var = std::move(f_visit_var); |
| n->f_visit_size_var = std::move(f_visit_size_var); |
| n->f_visit_buffer_load = std::move(f_visit_buffer_load); |
| n->f_visit_producer_load = std::move(f_visit_producer_load); |
| n->f_visit_let = std::move(f_visit_let); |
| n->f_visit_call = std::move(f_visit_call); |
| n->f_visit_add = std::move(f_visit_add); |
| n->f_visit_sub = std::move(f_visit_sub); |
| n->f_visit_mul = std::move(f_visit_mul); |
| n->f_visit_div = std::move(f_visit_div); |
| n->f_visit_mod = std::move(f_visit_mod); |
| n->f_visit_floor_div = std::move(f_visit_floor_div); |
| n->f_visit_floor_mod = std::move(f_visit_floor_mod); |
| n->f_visit_min = std::move(f_visit_min); |
| n->f_visit_max = std::move(f_visit_max); |
| n->f_visit_eq = std::move(f_visit_eq); |
| n->f_visit_ne = std::move(f_visit_ne); |
| n->f_visit_lt = std::move(f_visit_lt); |
| n->f_visit_le = std::move(f_visit_le); |
| n->f_visit_gt = std::move(f_visit_gt); |
| n->f_visit_ge = std::move(f_visit_ge); |
| n->f_visit_and = std::move(f_visit_and); |
| n->f_visit_or = std::move(f_visit_or); |
| n->f_visit_reduce = std::move(f_visit_reduce); |
| n->f_visit_cast = std::move(f_visit_cast); |
| n->f_visit_not = std::move(f_visit_not); |
| n->f_visit_select = std::move(f_visit_select); |
| n->f_visit_ramp = std::move(f_visit_ramp); |
| n->f_visit_broadcast = std::move(f_visit_broadcast); |
| n->f_visit_shuffle = std::move(f_visit_shuffle); |
| n->f_visit_int_imm = std::move(f_visit_int_imm); |
| n->f_visit_float_imm = std::move(f_visit_float_imm); |
| n->f_visit_string_imm = std::move(f_visit_string_imm); |
| return PyStmtExprMutator(n); |
| } |
| |
| TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PyStmtExprMutator, ObjectRef, |
| PyStmtExprMutatorNode); |
| }; |
| |
| // ================================================ |
| // TVM Register |
| // ================================================ |
| |
// Run the RegisterReflection() hook of the Python-overridable visitor and mutator node
// types at static-initialization time.
// NOTE(review): kept in this order (visitor first) — registration order is not changed
// in case the hooks allocate type indices lazily.
TVM_FFI_STATIC_INIT_BLOCK() {
  PyStmtExprVisitorNode::RegisterReflection();
  PyStmtExprMutatorNode::RegisterReflection();
}
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("tir.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) |
| .def("tir.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); |
| } |
| |
| // StmtExprVisitor |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("tir.PyStmtExprVisitorDefaultVisitExpr", |
| [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->DefaultVisitExpr(expr); }) |
| .def("tir.PyStmtExprVisitorDefaultVisitStmt", |
| [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->DefaultVisitStmt(stmt); }) |
| .def("tir.PyStmtExprVisitorVisitStmt", |
| [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }) |
| .def("tir.PyStmtExprVisitorVisitExpr", |
| [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->VisitExpr(expr); }); |
| } |
| |
| // StmtExprMutator |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("tir.PyStmtExprMutatorDefaultVisitExpr", |
| [](PyStmtExprMutator mutator, const PrimExpr& expr) { |
| return mutator->DefaultVisitExpr(expr); |
| }) |
| .def("tir.PyStmtExprMutatorDefaultVisitStmt", |
| [](PyStmtExprMutator mutator, const Stmt& stmt) { |
| return mutator->DefaultVisitStmt(stmt); |
| }) |
| .def("tir.PyStmtExprMutatorVisitExpr", |
| [](PyStmtExprMutator mutator, const PrimExpr& expr) { return mutator->VisitExpr(expr); }) |
| .def("tir.PyStmtExprMutatorVisitStmt", |
| [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->VisitStmt(stmt); }); |
| } |
| |
| } // namespace tir |
| } // namespace tvm |