| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/tir/stmt.cc |
| */ |
| #include <tvm/runtime/registry.h> |
| #include <tvm/tir/op.h> |
| #include <tvm/tir/op_attr_types.h> |
| #include <tvm/tir/stmt.h> |
| |
| namespace tvm { |
| namespace tir { |
| |
| // LetStmt |
| LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) { |
| CHECK(value.defined()); |
| CHECK(body.defined()); |
| CHECK_EQ(value.dtype(), var.dtype()); |
| |
| ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>(); |
| node->var = std::move(var); |
| node->value = std::move(value); |
| node->body = std::move(body); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed([](Var var, PrimExpr value, Stmt body) { |
| return LetStmt(var, value, body); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(LetStmtNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const LetStmtNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "let " << op->var << " = "; |
| p->Print(op->value); |
| p->stream << '\n'; |
| p->Print(op->body); |
| }); |
| |
| // AttrStmt |
| AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body) { |
| auto n = make_object<AttrStmtNode>(); |
| n->node = node; |
| n->attr_key = std::move(attr_key); |
| n->value = std::move(value); |
| n->body = std::move(body); |
| data_ = std::move(n); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.AttrStmt") |
| .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body) { |
| return AttrStmt(node, attr_key, value, body); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(AttrStmtNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const AttrStmtNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "// attr ["; |
| p->Print(op->node); |
| p->stream << "] " << op->attr_key << " = "; |
| p->Print(op->value); |
| p->stream << '\n'; |
| p->Print(op->body); |
| }); |
| |
| // AssertStmt |
| AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) { |
| CHECK(condition.defined()); |
| CHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>()) |
| << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; |
| |
| ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>(); |
| node->condition = std::move(condition); |
| node->message = std::move(message); |
| node->body = std::move(body); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_NODE_TYPE(AssertStmtNode); |
| |
| TVM_REGISTER_GLOBAL("tir.AssertStmt") |
| .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { |
| if (const auto* str = message.as<StringObj>()) { |
| auto msg = StringImm(str->data); |
| return AssertStmt(condition, msg, body); |
| } else { |
| return AssertStmt(condition, Downcast<PrimExpr>(message), body); |
| } |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const AssertStmtNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "assert("; |
| p->Print(op->condition); |
| p->stream << ", "; |
| p->Print(op->message); |
| p->stream << ")\n"; |
| p->Print(op->body); |
| }); |
| |
| // For |
| For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, |
| Stmt body) { |
| CHECK(min.defined()); |
| CHECK(extent.defined()); |
| CHECK(min.dtype().is_scalar()); |
| CHECK(extent.dtype().is_scalar()); |
| CHECK(loop_var.dtype().is_scalar()); |
| CHECK(body.defined()); |
| |
| ObjectPtr<ForNode> node = make_object<ForNode>(); |
| node->loop_var = std::move(loop_var); |
| node->min = std::move(min); |
| node->extent = std::move(extent); |
| node->for_type = for_type; |
| node->device_api = device_api; |
| node->body = std::move(body); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, |
| int for_type, int device_api, Stmt body) { |
| return For(loop_var, min, extent, static_cast<ForType>(for_type), |
| static_cast<DeviceAPI>(device_api), body); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(ForNode); |
| |
| std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*) |
| switch (type) { |
| case ForType::Serial: |
| out << "for"; |
| break; |
| case ForType::Parallel: |
| out << "parallel"; |
| break; |
| case ForType::Unrolled: |
| out << "unrolled"; |
| break; |
| case ForType::Vectorized: |
| out << "vectorized"; |
| break; |
| } |
| return out; |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const ForNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << op->for_type << " (" << op->loop_var << ", "; |
| p->Print(op->min); |
| p->stream << ", "; |
| p->Print(op->extent); |
| p->stream << ") {\n"; |
| |
| p->indent += 2; |
| p->Print(op->body); |
| p->indent -= 2; |
| |
| p->PrintIndent(); |
| p->stream << "}\n"; |
| }); |
| |
| // Store |
| Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { |
| CHECK(value.defined()); |
| CHECK(index.defined()); |
| CHECK(predicate.defined()); |
| CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); |
| CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); |
| |
| ObjectPtr<StoreNode> node = make_object<StoreNode>(); |
| node->buffer_var = std::move(buffer_var); |
| node->value = std::move(value); |
| node->index = std::move(index); |
| node->predicate = std::move(predicate); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { |
| PrimExpr value = args[1]; |
| if (args.size() == 3) { |
| *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes())); |
| } else { |
| *ret = Store(args[0], value, args[2], args[3]); |
| } |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(StoreNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const StoreNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << op->buffer_var << "["; |
| p->Print(op->index); |
| p->stream << "] = "; |
| p->Print(op->value); |
| if (!is_one(op->predicate)) { |
| p->stream << " if "; |
| p->Print(op->predicate); |
| } |
| p->stream << '\n'; |
| }); |
| |
| // ProducerStore |
| ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices) { |
| ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>(); |
| node->producer = std::move(producer); |
| node->value = std::move(value); |
| node->indices = std::move(indices); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.ProducerStore") |
| .set_body_typed([](DataProducer producer, PrimExpr value, Array<PrimExpr> indices) { |
| return ProducerStore(producer, value, indices); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(ProducerStoreNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<ProducerStoreNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const ProducerStoreNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << op->producer->GetNameHint() << "["; |
| for (size_t i = 0; i < op->indices.size(); ++i) { |
| p->Print(op->indices[i]); |
| if (i < op->indices.size() - 1) p->stream << ", "; |
| } |
| p->stream << "]"; |
| p->stream << " ="; |
| p->Print(op->value); |
| p->stream << '\n'; |
| }); |
| |
| // Allocate |
| Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition, |
| Stmt body) { |
| // TODO(tvm-team): Add invariant check to make sure |
| // IsPointerPType(buffer_var->type_annotation, dtype) |
| // once we fix the allocate tvm script printing. |
| for (size_t i = 0; i < extents.size(); ++i) { |
| CHECK(extents[i].defined()); |
| CHECK(extents[i].dtype().is_scalar()); |
| } |
| CHECK(body.defined()); |
| CHECK(condition.defined()); |
| CHECK(condition.dtype().is_bool()); |
| |
| ObjectPtr<AllocateNode> node = make_object<AllocateNode>(); |
| node->buffer_var = std::move(buffer_var); |
| node->dtype = dtype; |
| node->extents = std::move(extents); |
| node->condition = std::move(condition); |
| node->body = std::move(body); |
| data_ = std::move(node); |
| } |
| |
| int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) { |
| int64_t result = 1; |
| for (size_t i = 0; i < extents.size(); ++i) { |
| if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) { |
| result *= int_size->value; |
| if (result > std::numeric_limits<int32_t>::max()) { |
| return 0; |
| } |
| } else { |
| return 0; |
| } |
| } |
| return static_cast<int32_t>(result); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.Allocate") |
| .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, |
| Stmt body) { return Allocate(buffer_var, type, extents, condition, body); }); |
| |
| TVM_REGISTER_NODE_TYPE(AllocateNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const AllocateNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "allocate " << op->buffer_var << "[" << op->dtype; |
| for (size_t i = 0; i < op->extents.size(); ++i) { |
| p->stream << " * "; |
| p->Print(op->extents[i]); |
| } |
| p->stream << "]"; |
| if (!is_one(op->condition)) { |
| p->stream << " if "; |
| p->Print(op->condition); |
| } |
| p->stream << "\n"; |
| p->Print(op->body); |
| }); |
| |
| // ProducerRealize |
| ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, |
| Stmt body) { |
| for (size_t i = 0; i < bounds.size(); ++i) { |
| CHECK(bounds[i]->min.defined()); |
| CHECK(bounds[i]->extent.defined()); |
| CHECK(bounds[i]->min.dtype().is_scalar()); |
| CHECK(bounds[i]->extent.dtype().is_scalar()); |
| } |
| CHECK(body.defined()); |
| CHECK(condition.defined()); |
| CHECK(condition.dtype().is_bool()); |
| |
| ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>(); |
| node->producer = std::move(producer); |
| node->bounds = std::move(bounds); |
| node->condition = std::move(condition); |
| node->body = std::move(body); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.ProducerRealize") |
| .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body) { |
| return ProducerRealize(producer, bounds, condition, body); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<ProducerRealizeNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const ProducerRealizeNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "producer_realize " << op->producer->GetNameHint() << "("; |
| for (size_t i = 0; i < op->bounds.size(); ++i) { |
| p->stream << "["; |
| p->Print(op->bounds[i]->min); |
| p->stream << ", "; |
| p->Print(op->bounds[i]->extent); |
| p->stream << "]"; |
| if (i < op->bounds.size() - 1) p->stream << ", "; |
| } |
| p->stream << ")"; |
| if (!is_one(op->condition)) { |
| p->stream << " if "; |
| p->Print(op->condition); |
| } |
| p->stream << " {\n"; |
| |
| p->indent += 2; |
| p->Print(op->body); |
| p->indent -= 2; |
| |
| p->PrintIndent(); |
| p->stream << "}\n"; |
| }); |
| |
| // Prefetch |
| Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) { |
| data_ = make_object<PrefetchNode>(buffer, bounds); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array<Range> bounds) { |
| return Prefetch(buffer, bounds); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(PrefetchNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const PrefetchNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "prefetch " << op->buffer << "("; |
| for (size_t i = 0; i < op->bounds.size(); ++i) { |
| p->stream << "["; |
| p->Print(op->bounds[i]->min); |
| p->stream << ", "; |
| p->Print(op->bounds[i]->extent); |
| p->stream << "]"; |
| if (i < op->bounds.size() - 1) p->stream << ", "; |
| } |
| p->stream << ")"; |
| }); |
| |
| // SeqStmt |
| SeqStmt::SeqStmt(Array<Stmt> seq) { |
| auto node = make_object<SeqStmtNode>(); |
| node->seq = std::move(seq); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq) { |
| return SeqStmt(std::move(seq)); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(SeqStmtNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const SeqStmtNode*>(node.get()); |
| for (Stmt stmt : op->seq) { |
| p->Print(stmt); |
| } |
| }); |
| |
| // IfThenElse |
| IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) { |
| CHECK(condition.defined()); |
| CHECK(then_case.defined()); |
| // else_case may be null. |
| ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>(); |
| node->condition = std::move(condition); |
| node->then_case = std::move(then_case); |
| node->else_case = std::move(else_case); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_NODE_TYPE(IfThenElseNode); |
| |
| TVM_REGISTER_GLOBAL("tir.IfThenElse") |
| .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case) { |
| return IfThenElse(condition, then_case, else_case); |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const IfThenElseNode*>(node.get()); |
| p->PrintIndent(); |
| while (true) { |
| p->stream << "if (" << op->condition << ") {\n"; |
| p->indent += 2; |
| p->Print(op->then_case); |
| p->indent -= 2; |
| |
| if (!op->else_case.defined()) { |
| break; |
| } |
| |
| if (const IfThenElseNode* nested_if = op->else_case.as<IfThenElseNode>()) { |
| p->PrintIndent(); |
| p->stream << "} else "; |
| op = nested_if; |
| } else { |
| p->PrintIndent(); |
| p->stream << "} else {\n"; |
| p->indent += 2; |
| p->Print(op->else_case); |
| p->indent -= 2; |
| break; |
| } |
| } |
| p->PrintIndent(); |
| p->stream << "}\n"; |
| }); |
| |
| // Evaluate |
| Evaluate::Evaluate(PrimExpr value) { |
| CHECK(value.defined()); |
| |
| ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>(); |
| node->value = std::move(value); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value) { return Evaluate(value); }); |
| |
| TVM_REGISTER_NODE_TYPE(EvaluateNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const EvaluateNode*>(node.get()); |
| p->PrintIndent(); |
| p->Print(op->value); |
| p->stream << "\n"; |
| }); |
| |
| // BufferStore |
| BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) { |
| ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>(); |
| node->buffer = std::move(buffer); |
| node->value = std::move(value); |
| node->indices = std::move(indices); |
| data_ = std::move(node); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.BufferStore") |
| .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) { |
| return BufferStore(buffer, value, indices); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(BufferStoreNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const BufferStoreNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << op->buffer->name << "["; |
| for (size_t i = 0; i < op->indices.size(); ++i) { |
| p->Print(op->indices[i]); |
| if (i < op->indices.size() - 1) p->stream << ", "; |
| } |
| p->stream << "]"; |
| p->stream << " = "; |
| p->Print(op->value); |
| p->stream << '\n'; |
| }); |
| |
| // BufferRealize |
| BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) { |
| data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body); |
| } |
| |
| TVM_REGISTER_GLOBAL("tir.BufferRealize") |
| .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) { |
| return BufferRealize(buffer, bounds, condition, body); |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(BufferRealizeNode); |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const BufferRealizeNode*>(node.get()); |
| p->PrintIndent(); |
| p->stream << "buffer_realize " << op->buffer->name << "("; |
| for (size_t i = 0; i < op->bounds.size(); ++i) { |
| p->stream << "["; |
| p->Print(op->bounds[i]->min); |
| p->stream << ", "; |
| p->Print(op->bounds[i]->extent); |
| p->stream << "]"; |
| if (i < op->bounds.size() - 1) p->stream << ", "; |
| } |
| p->stream << ")"; |
| if (!is_one(op->condition)) { |
| p->stream << " if "; |
| p->Print(op->condition); |
| } |
| p->stream << " {\n"; |
| |
| p->indent += 2; |
| p->Print(op->body); |
| p->indent -= 2; |
| |
| p->PrintIndent(); |
| p->stream << "}\n"; |
| }); |
| |
| PrimExpr TypeAnnotation(DataType dtype) { |
| static auto op = Op::Get("tir.type_annotation"); |
| return tir::Call(dtype, op, {}); |
| } |
| |
| TVM_REGISTER_OP("tir.type_annotation") |
| .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); |
| |
| } // namespace tir |
| } // namespace tvm |