| /*! |
| * Copyright (c) 2017 by Contributors |
 * Exposes the IR pass functions through the API registry.
| * \file api_pass.cc |
| */ |
| #include <tvm/expr.h> |
| #include <tvm/ir.h> |
| #include <tvm/attrs.h> |
| #include <tvm/ir_pass.h> |
| #include <tvm/ir_visitor.h> |
| #include <tvm/ir_mutator.h> |
| #include <tvm/api_registry.h> |
| |
| namespace tvm { |
| namespace ir { |
| |
| TVM_REGISTER_API("ir_pass.Simplify") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| if (args[0].IsNodeType<Stmt>()) { |
| if (args.size() > 1) { |
| *ret = Simplify(args[0].operator Stmt(), args[1]); |
| } else { |
| *ret = Simplify(args[0].operator Stmt()); |
| } |
| } else { |
| if (args.size() > 1) { |
| *ret = Simplify(args[0].operator Expr(), args[1]); |
| } else { |
| *ret = Simplify(args[0].operator Expr()); |
| } |
| } |
| }); |
| |
| TVM_REGISTER_API("ir_pass.CanonicalSimplify") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| if (args[0].IsNodeType<Stmt>()) { |
| if (args.size() > 1) { |
| *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); |
| } else { |
| *ret = CanonicalSimplify(args[0].operator Stmt()); |
| } |
| } else { |
| if (args.size() > 1) { |
| *ret = CanonicalSimplify(args[0].operator Expr(), args[1]); |
| } else { |
| *ret = CanonicalSimplify(args[0].operator Expr()); |
| } |
| } |
| }); |
| |
| TVM_REGISTER_API("ir_pass.Substitute") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| if (args[0].IsNodeType<Stmt>()) { |
| *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>()); |
| } else { |
| *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>()); |
| } |
| }); |
| |
| TVM_REGISTER_API("ir_pass.Equal") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| if (args[0].IsNodeType<Stmt>()) { |
| *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); |
| } else { |
| *ret = Equal(args[0].operator Expr(), args[1].operator Expr()); |
| } |
| }); |
| |
| TVM_REGISTER_API("ir_pass.StorageFlatten") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| if (args.size() <= 3) { |
| *ret = StorageFlatten(args[0], args[1], args[2]); |
| } else { |
| *ret = StorageFlatten(args[0], args[1], args[2], args[3]); |
| } |
| }); |
| |
| TVM_REGISTER_API("ir_pass.AttrsEqual") |
| .set_body_typed<bool(const NodeRef&, const NodeRef&)>([](const NodeRef& lhs, const NodeRef& rhs) { |
| return AttrsEqual()(lhs, rhs); |
| }); |
| |
| TVM_REGISTER_API("ir_pass.AttrsHash") |
| .set_body_typed<int64_t(const NodeRef&)>([](const NodeRef &node) { |
| return AttrsHash()(node); |
| }); |
| |
| |
| TVM_REGISTER_API("ir_pass.ExprUseVar") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var()); |
| }); |
| |
| TVM_REGISTER_API("ir_pass.PostOrderVisit") |
| .set_body([](TVMArgs args, TVMRetValue *ret) { |
| PackedFunc f = args[1]; |
| ir::PostOrderVisit(args[0], [f](const NodeRef& n) { |
| f(n); |
| }); |
| }); |
| |
// REGISTER_PASSn(PassName) registers a packed function named
// "ir_pass.<PassName>". Its body forwards the first n packed arguments
// (args[0] .. args[n-1]) positionally to PassName and stores the result in
// *ret. The expansion ends in `})`, so every use site supplies the trailing
// semicolon.
//
// No macro ends in a line-continuation backslash: a dangling `\` would splice
// whatever line follows into the macro body.
#define REGISTER_PASS1(PassName)                                          \
  TVM_REGISTER_API("ir_pass."#PassName)                                   \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                          \
      *ret = PassName(args[0]);                                           \
    })

#define REGISTER_PASS2(PassName)                                          \
  TVM_REGISTER_API("ir_pass."#PassName)                                   \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                          \
      *ret = PassName(args[0], args[1]);                                  \
    })

#define REGISTER_PASS3(PassName)                                          \
  TVM_REGISTER_API("ir_pass."#PassName)                                   \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                          \
      *ret = PassName(args[0], args[1], args[2]);                         \
    })

#define REGISTER_PASS4(PassName)                                          \
  TVM_REGISTER_API("ir_pass."#PassName)                                   \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                          \
      *ret = PassName(args[0], args[1], args[2], args[3]);                \
    })

#define REGISTER_PASS5(PassName)                                          \
  TVM_REGISTER_API("ir_pass."#PassName)                                   \
  .set_body([](TVMArgs args, TVMRetValue *ret) {                          \
      *ret = PassName(args[0], args[1], args[2], args[3], args[4]);       \
    })
| |
// Fixed-arity pass registrations.
// The digit in REGISTER_PASSn is the number of packed arguments forwarded
// positionally to the pass. It must agree with the pass's C++ declaration
// (presumably in tvm/ir_pass.h — confirm when adding or changing an entry).
// Each macro expansion ends in `})`; the `;` on every line terminates it.
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline);
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS5(UnrollLoop);
REGISTER_PASS3(InjectCopyIntrin);
REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS2(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
REGISTER_PASS1(InstrumentBoundCheckers);
| } // namespace ir |
| } // namespace tvm |