blob: 9ddffe19a878f79365b377bc0917e975d6fced0d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../../../tir/transform/ir_utils.h" // For `GetPtrStorageScope`
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
Doc DoConciseScoping(const ffi::Optional<ExprDoc>& lhs, const ExprDoc& rhs,
ffi::Array<StmtDoc>* stmts, bool concise_scoping) {
if (concise_scoping) {
if (lhs.defined()) {
stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, std::nullopt));
} else {
stmts->insert(stmts->begin(), ExprStmtDoc(rhs));
}
return StmtBlockDoc(*stmts);
} else {
return ScopeDoc(lhs, rhs, *stmts);
}
}
bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) {
if (d->cfg.defined()) {
if (d->cfg->obj_to_annotate.count(obj)) {
// if the object requires annotation, do not fold this frame
return false;
}
}
ICHECK(!d->frames.empty());
if (const auto* f = d->frames.back().as<TIRFrameNode>()) {
return f->allow_concise_scoping;
}
LOG(FATAL) << "NotImplementedError: fragment printing";
}
bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) {
if (!d->common_prefix.count(var.get())) {
return false;
}
const std::vector<const Object*>& path = d->common_prefix.at(var.get());
for (auto it = path.rbegin(); it != path.rend(); ++it) {
if (*it == node.get()) {
return true;
}
}
return false;
}
ffi::Optional<PrimExpr> FindReturnValue(const tir::Stmt& node) {
auto eval = node.as<tir::EvaluateNode>();
if (!eval) return std::nullopt;
auto call = eval->value.as<tir::CallNode>();
if (!call) return std::nullopt;
if (!call->op.same_as(tir::builtin::ret())) return std::nullopt;
if (call->args.size() != 1) return std::nullopt;
return call->args[0];
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc {
if (d->cfg->syntax_sugar) {
if (auto return_value = FindReturnValue(eval)) {
ExprDoc value =
d->AsDoc<ExprDoc>(return_value.value(), p->Attr("value")->Attr("args")->ArrayItem(0));
return ReturnDoc(value);
}
}
ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
if (eval->value->IsInstance<tir::CallNode>()) {
return ExprStmtDoc(value);
}
return ExprStmtDoc(TIR(d, "evaluate")->Call({value}));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt);
// Step 1. Type annotation
ffi::Optional<ExprDoc> type_doc = d->AsDoc<ExprDoc>(stmt->var->type_annotation, //
p->Attr("var")->Attr("type_annotation"));
if (const auto* tuple_type = stmt->var->type_annotation.as<TupleTypeNode>()) {
if (tuple_type->fields.empty()) {
type_doc = std::nullopt;
}
}
// Step 2. RHS
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
// Step 3. LHS and body
With<TIRFrame> f(d, stmt);
ffi::Array<StmtDoc>* stmts = &(*f)->stmts;
bool var_defined = d->IsVarDefined(stmt->var);
if (!var_defined) {
DefineVar(stmt->var, *f, d);
}
ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
// Step 4. Dispatch
if (var_defined) {
return ScopeDoc(std::nullopt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts);
} else if (concise) {
stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
return StmtBlockDoc(*stmts);
} else if (type_doc.defined() && !stmt->var->type_annotation->IsInstance<PrimTypeNode>()) {
return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs, type_doc.value()}), *stmts);
} else {
return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs}), *stmts);
}
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AssertStmt>(
"", [](tir::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt);
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
if (concise) {
ffi::Array<StmtDoc>* stmts = &(*f)->stmts;
stmts->insert(stmts->begin(), AssertDoc(cond, msg));
return StmtBlockDoc(*stmts);
}
return ScopeDoc(std::nullopt, TIR(d, "Assert")->Call({cond, msg}), (*f)->stmts);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::While>("", [](tir::While stmt, AccessPath p, IRDocsifier d) -> Doc {
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
return WhileDoc(cond, (*f)->stmts);
});
namespace {
Doc DeclBufferDoc(tir::DeclBuffer stmt, AccessPath p, IRDocsifier d,
BufferVarDefinition var_definitions) {
bool concise = AllowConciseScoping(d, stmt);
ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d,
var_definitions);
With<TIRFrame> f(d, stmt);
ExprDoc lhs = DefineBuffer(stmt->buffer, *f, d);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
}
} // namespace
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::DeclBuffer>( //
"", [](tir::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc {
return DeclBufferDoc(stmt, p, d, BufferVarDefinition::None);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::IfThenElse>( //
"", [](tir::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc {
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
ffi::Array<StmtDoc> then_branch;
ffi::Array<StmtDoc> else_branch;
if (stmt->then_case.defined()) {
With<TIRFrame> f(d, stmt->then_case);
AsDocBody(stmt->then_case, p->Attr("then_case"), f->get(), d);
then_branch = (*f)->stmts;
}
if (stmt->else_case.defined()) {
With<TIRFrame> f(d, stmt->else_case);
AsDocBody(stmt->else_case.value(), p->Attr("else_case"), f->get(), d);
else_branch = (*f)->stmts;
}
return IfDoc(cond, then_branch, else_branch);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::SeqStmt>("", [](tir::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
With<TIRFrame> f(d, stmt);
AsDocBody(stmt, p, f->get(), d);
return StmtBlockDoc((*f)->stmts);
});
bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) {
const tir::Var& buffer_var = allocate->buffer_var;
if (const tir::DeclBufferNode* decl_buffer = allocate->body.as<tir::DeclBufferNode>()) {
const tir::Buffer& buffer = decl_buffer->buffer;
if (buffer_var.same_as(buffer->data) && allocate->dtype == buffer->dtype &&
tir::is_one(allocate->condition) && !allocate->annotations.size() &&
allocate->extents.size() == buffer->shape.size()) {
tir::ExprDeepEqual expr_equal;
for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
return false;
}
}
return true;
}
}
return false;
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Allocate>( //
"", [](tir::Allocate stmt, AccessPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt_p);
if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) {
return DeclBufferDoc(Downcast<tir::DeclBuffer>(stmt->body), stmt_p->Attr("body"), d,
BufferVarDefinition::DataPointer);
}
ffi::Array<ExprDoc> args;
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
args.push_back(d->AsDoc<ExprDoc>(stmt->extents, stmt_p->Attr("extents")));
args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype")));
args.push_back(LiteralDoc::Str(tir::GetPtrStorageScope(stmt->buffer_var),
stmt_p
->Attr("buffer_var") //
->Attr("type_annotation")
->Attr("storage_scope")));
if (!tir::is_one(stmt->condition)) {
args.push_back(d->AsDoc<ExprDoc>(stmt->condition, stmt_p->Attr("condition")));
}
if (!stmt->annotations.empty()) {
kwargs_keys.push_back("annotations");
kwargs_values.push_back(
d->AsDoc<ExprDoc>(stmt->annotations, stmt_p->Attr("annotations")));
}
ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d);
With<TIRFrame> f(d, stmt);
ExprDoc rhs = TIR(d, "allocate")->Call(args, kwargs_keys, kwargs_values);
AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d);
return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
});
template <typename T>
ExprDoc PrintTensor(::tvm::runtime::Tensor arr) {
// FIXME(@junrushao): this is a hack and can be wrong in most of the cases
constexpr int NUM_PRINT = 200;
int ndim = arr->ndim;
int tot_dim = 1;
for (int i = 0; i < ndim; i++) {
tot_dim *= arr->shape[i];
}
ffi::Array<ExprDoc> result;
T* data_ptr = reinterpret_cast<T*>(arr->data);
runtime::DataType dtype = arr.DataType();
for (int i = 0; i < tot_dim; i++) {
if (dtype.is_float()) {
result.push_back(LiteralDoc::Float(data_ptr[i], std::nullopt));
} else {
result.push_back(LiteralDoc::Int(data_ptr[i], std::nullopt));
}
if (i == NUM_PRINT) {
break;
}
}
return ListDoc(result);
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AllocateConst>(
"", [](tir::AllocateConst stmt, AccessPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt);
ffi::String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
ffi::Array<ExprDoc> args;
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
ExprDoc data_doc{ffi::UnsafeInit()};
if (stmt->dtype.is_int()) {
if (stmt->dtype.bits() == 8) {
data_doc = PrintTensor<int8_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 16) {
data_doc = PrintTensor<int16_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 32) {
data_doc = PrintTensor<int32_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 64) {
data_doc = PrintTensor<int64_t>(stmt->data.value());
} else {
LOG(FATAL) << "DataType not supported";
}
} else if (stmt->dtype.is_uint()) {
if (stmt->dtype.bits() == 8) {
data_doc = PrintTensor<uint8_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 16) {
data_doc = PrintTensor<uint16_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 32) {
data_doc = PrintTensor<uint32_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 64) {
data_doc = PrintTensor<uint64_t>(stmt->data.value());
} else {
LOG(FATAL) << "DataType not supported";
}
} else if (stmt->dtype.is_float()) {
if (stmt->dtype.bits() == 16) {
data_doc = PrintTensor<int16_t>(stmt->data.value());
} else if (stmt->dtype.bits() == 32) {
data_doc = PrintTensor<float>(stmt->data.value());
} else if (stmt->dtype.bits() == 64) {
data_doc = PrintTensor<double>(stmt->data.value());
} else {
LOG(FATAL) << "DataType not supported";
}
} else {
LOG(FATAL) << "DataType not supported";
}
args.push_back(data_doc);
args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype")));
args.push_back(d->AsDoc<ExprDoc>(stmt->extents, stmt_p->Attr("extents")));
ExprDoc rhs = TIR(d, "allocate_const")->Call(args, kwargs_keys, kwargs_values);
With<TIRFrame> f(d, stmt);
ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d);
AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d);
return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
});
ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, ffi::Optional<ExprDoc> value, //
AccessPath p, IRDocsifier d) {
ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer"));
{
ffi::Array<Doc> bounds;
bounds.reserve(stmt->bounds.size());
for (int i = 0, n = stmt->bounds.size(); i < n; ++i) {
Range range = stmt->bounds[i];
AccessPath range_p = p->Attr("bounds")->ArrayItem(i);
bounds.push_back(
SliceDoc(d->AsDoc<ExprDoc>(range->min, range_p->Attr("min")),
d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent")), //
std::nullopt));
}
buffer = buffer[bounds];
}
ffi::Array<ExprDoc> args{buffer};
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
if (value.defined()) {
args.push_back(value.value());
}
if (!tir::is_one(stmt->condition)) {
kwargs_keys.push_back("condition");
kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition")));
}
return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
}
void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath& iter_var_p,
const IRDocsifier& d) {
Frame f = FindLowestVarDef(iter_var->var, d).value();
DefineVar(iter_var->var, f, d);
ExprDoc rhs = TIR(d, "env_thread")
->Call({LiteralDoc::Str(iter_var->thread_tag, //
iter_var_p->Attr("thread_tag"))});
ExprDoc lhs = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt));
}
ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p,
ffi::Optional<tir::Var>* define_var, const IRDocsifier& d) {
tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
AccessPath iter_var_p = attr_stmt_p->Attr("node");
ExprDoc var_doc{ffi::UnsafeInit()};
if (d->IsVarDefined(iter_var->var)) {
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
} else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
var_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"));
*define_var = iter_var->var;
} else {
InsertEnvThread(iter_var, iter_var_p, d);
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
}
return TIR(d, "launch_thread")
->Call({
var_doc,
d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value")),
});
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRealize>( //
"", [](tir::BufferRealize stmt, AccessPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt);
ExprDoc rhs = DocsifyBufferRealize(stmt.get(), std::nullopt, p, d);
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
return DoConciseScoping(std::nullopt, rhs, &(*f)->stmts, concise);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt);
ffi::Optional<ExprDoc> lhs = std::nullopt;
ffi::Optional<ExprDoc> rhs = std::nullopt;
ffi::Optional<tir::Var> define_var = std::nullopt;
tir::Stmt body = stmt->body;
AccessPath body_p = stmt_p->Attr("body");
if (stmt->attr_key == "realize_scope") {
if (const auto* realize = stmt->body.as<tir::BufferRealizeNode>()) {
// TODO(tqchen): add any.same_as(ObjectRef)
if (realize->buffer.same_as(stmt->node.cast<ObjectRef>())) {
rhs = DocsifyBufferRealize(
realize,
/*value=*/d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
/*p=*/stmt_p->Attr("body"), d);
body = realize->body;
body_p = stmt_p->Attr("body")->Attr("body");
}
}
}
if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
if (stmt->node.as<tir::IterVarNode>()) {
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
}
}
if (!rhs.defined()) {
rhs = TIR(d, "attr")->Call({
d->AsDoc<ExprDoc>(stmt->node, stmt_p->Attr("node")),
LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")),
d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
});
}
With<TIRFrame> f(d, stmt);
if (define_var.defined()) {
lhs = DefineVar(define_var.value(), *f, d);
}
AsDocBody(body, body_p, f->get(), d);
return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
});
TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::BufferRealizeNode, ReprPrintTIR);
} // namespace printer
} // namespace script
} // namespace tvm