blob: 1c861f91dce2197c386c70d25add8d196021556b [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../../../tirx/transform/ir_utils.h" // For `GetPtrStorageScope`
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
Doc DoConciseScoping(const ffi::Optional<ExprDoc>& lhs, const ExprDoc& rhs,
ffi::Array<StmtDoc>* stmts, bool concise_scoping) {
if (concise_scoping) {
if (lhs.defined()) {
stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, std::nullopt));
} else {
stmts->insert(stmts->begin(), ExprStmtDoc(rhs));
}
return StmtBlockDoc(*stmts);
} else {
return ScopeDoc(lhs, rhs, *stmts);
}
}
bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) {
if (d->cfg.defined()) {
if (d->cfg->obj_to_annotate.count(obj)) {
// if the object requires annotation, do not fold this frame
return false;
}
}
TVM_FFI_ICHECK(!d->frames.empty());
if (const auto* f = d->frames.back().as<TIRFrameNode>()) {
return f->allow_concise_scoping;
}
TVM_FFI_THROW(NotImplementedError) << "fragment printing";
TVM_FFI_UNREACHABLE();
}
bool IsAncestorOfAllVarUse(const tirx::Stmt& node, const ObjectRef& var, const IRDocsifier& d) {
if (!d->common_prefix.count(var.get())) {
return false;
}
const std::vector<const Object*>& path = d->common_prefix.at(var.get());
for (auto it = path.rbegin(); it != path.rend(); ++it) {
if (*it == node.get()) {
return true;
}
}
return false;
}
ffi::Optional<PrimExpr> FindReturnValue(const tirx::Stmt& node) {
auto eval = node.as<tirx::EvaluateNode>();
if (!eval) return std::nullopt;
auto call = eval->value.as<tirx::CallNode>();
if (!call) return std::nullopt;
if (!call->op.same_as(tirx::builtin::ret())) return std::nullopt;
if (call->args.size() != 1) return std::nullopt;
return call->args[0];
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::Evaluate>("", [](tirx::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc {
if (d->cfg->syntax_sugar) {
if (auto return_value = FindReturnValue(eval)) {
ExprDoc value =
d->AsDoc<ExprDoc>(return_value.value(), p->Attr("value")->Attr("args")->ArrayItem(0));
return ReturnDoc(value);
}
}
ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
if (eval->value->IsInstance<tirx::CallNode>()) {
return ExprStmtDoc(value);
}
return ExprStmtDoc(TIR(d, "evaluate")->Call({value}));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::Bind>("", [](tirx::Bind stmt, AccessPath p, IRDocsifier d) -> Doc {
// Step 1. Type annotation
ffi::Optional<ExprDoc> type_doc = d->AsDoc<ExprDoc>(stmt->var->type_annotation, //
p->Attr("var")->Attr("type_annotation"));
if (const auto* tuple_type = stmt->var->type_annotation.as<TupleTypeNode>()) {
if (tuple_type->fields.empty()) {
type_doc = std::nullopt;
}
}
// Step 2. RHS
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
// Step 3. LHS - Bind is flat, define var if new, otherwise just assign
if (!d->IsVarDefined(stmt->var)) {
TVM_FFI_ICHECK(!d->frames.empty());
ExprDoc lhs = DefineVar(stmt->var, d->frames.back(), d);
return AssignDoc(lhs, rhs, type_doc);
} else {
ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
return AssignDoc(lhs, rhs, std::nullopt);
}
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::AssertStmt>(
"", [](tirx::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
// Always emit the canonical tuple form: assert cond, ("Kind", ["part0", "part1", ...])
ffi::Array<ExprDoc> parts;
auto parts_path = p->Attr("message_parts");
for (size_t i = 0; i < stmt->message_parts.size(); ++i) {
parts.push_back(d->AsDoc<ExprDoc>(stmt->message_parts[i], parts_path->ArrayItem(i)));
}
ExprDoc kind_doc = d->AsDoc<ExprDoc>(stmt->error_kind, p->Attr("error_kind"));
return AssertDoc(cond, TupleDoc({kind_doc, ListDoc(parts)}));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::While>("", [](tirx::While stmt, AccessPath p, IRDocsifier d) -> Doc {
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
return WhileDoc(cond, (*f)->stmts);
});
namespace {
Doc DeclBufferDoc(tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d,
BufferVarDefinition var_definitions) {
ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d,
var_definitions);
ExprDoc lhs = DefineBuffer(stmt->buffer, d->frames.back(), d);
return AssignDoc(lhs, rhs, std::nullopt);
}
} // namespace
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::DeclBuffer>( //
"", [](tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc {
return DeclBufferDoc(stmt, p, d, BufferVarDefinition::None);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::IfThenElse>( //
"", [](tirx::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc {
ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
ffi::Array<StmtDoc> then_branch;
ffi::Array<StmtDoc> else_branch;
if (stmt->then_case.defined()) {
With<TIRFrame> f(d, stmt->then_case);
AsDocBody(stmt->then_case, p->Attr("then_case"), f->get(), d);
then_branch = (*f)->stmts;
}
if (stmt->else_case.defined()) {
With<TIRFrame> f(d, stmt->else_case);
AsDocBody(stmt->else_case.value(), p->Attr("else_case"), f->get(), d);
else_branch = (*f)->stmts;
}
return IfDoc(cond, then_branch, else_branch);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::SeqStmt>("", [](tirx::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc {
With<TIRFrame> f(d, stmt);
AsDocBody(stmt, p, f->get(), d);
return StmtBlockDoc((*f)->stmts);
});
void InsertEnvThread(const tirx::IterVar& iter_var, const AccessPath& iter_var_p,
const IRDocsifier& d) {
Frame f = FindLowestVarDef(iter_var->var, d).value();
DefineVar(iter_var->var, f, d);
ExprDoc rhs = TIR(d, "env_thread")
->Call({LiteralDoc::Str(iter_var->thread_tag, //
iter_var_p->Attr("thread_tag"))});
ExprDoc lhs = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt));
}
ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p,
ffi::Optional<tirx::Var>* define_var, const IRDocsifier& d) {
tirx::IterVar iter_var = Downcast<tirx::IterVar>(attr_stmt->node);
AccessPath iter_var_p = attr_stmt_p->Attr("node");
ExprDoc var_doc{ffi::UnsafeInit()};
if (d->IsVarDefined(iter_var->var)) {
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
} else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
var_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"));
*define_var = iter_var->var;
} else {
InsertEnvThread(iter_var, iter_var_p, d);
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
}
return TIR(d, "launch_thread")
->Call({
var_doc,
d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value")),
});
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::AttrStmt>( //
"", [](tirx::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d, stmt);
ffi::Optional<ExprDoc> lhs = std::nullopt;
ffi::Optional<ExprDoc> rhs = std::nullopt;
ffi::Optional<tirx::Var> define_var = std::nullopt;
tirx::Stmt body = stmt->body;
AccessPath body_p = stmt_p->Attr("body");
if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
if (stmt->node.as<tirx::IterVarNode>()) {
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
}
}
if (!rhs.defined()) {
rhs = TIR(d, "attr")->Call({
d->AsDoc<ExprDoc>(stmt->node, stmt_p->Attr("node")),
LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")),
d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
});
}
With<TIRFrame> f(d, stmt);
if (define_var.defined()) {
lhs = DefineVar(define_var.value(), *f, d);
}
AsDocBody(body, body_p, f->get(), d);
return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
});
TVM_SCRIPT_REPR(tirx::BindNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::AttrStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::AssertStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::WhileNode, ReprPrintTIR);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::AllocBuffer>( //
"", [](tirx::AllocBuffer stmt, AccessPath p, IRDocsifier d) -> Doc {
tirx::Buffer buffer = stmt->buffer;
AccessPath buffer_p = p->Attr("buffer");
Frame frame = d->frames.back();
// Define buffer's data var inline as buffer.data
if (!d->IsVarDefined(buffer->data)) {
d->Define(buffer->data, frame, [buffer, buffer_p, d]() {
return d->AsDoc<ExprDoc>(buffer, buffer_p)->Attr("data");
});
}
// Build simplified T.alloc_buffer(shape, dtype, scope=...) call.
// Only print shape, dtype, scope (and annotations if non-empty).
ffi::Array<ExprDoc> args;
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
// shape (positional)
{
int n = buffer->shape.size();
ffi::Array<ExprDoc> shape_docs;
shape_docs.reserve(n);
AccessPath shape_p = buffer_p->Attr("shape");
for (int i = 0; i < n; ++i) {
PrimExpr e = buffer->shape[i];
AccessPath e_p = shape_p->ArrayItem(i);
if (!d->IsVarDefined(e) && e->IsInstance<tirx::VarNode>()) {
ExprDoc lhs = DefineVar(Downcast<tirx::Var>(e), frame, d);
lhs->source_paths.push_back(e_p);
frame->stmts.push_back(
AssignDoc(lhs, PrintVarCreation(Downcast<tirx::Var>(e), e_p, d), std::nullopt));
}
shape_docs.push_back(d->AsDoc<ExprDoc>(e, e_p));
}
args.push_back(TupleDoc(shape_docs));
}
// dtype (positional, skip if default float32)
if (buffer->dtype != d->cfg->buffer_dtype) {
args.push_back(LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype")));
}
// scope (keyword, skip if "global")
{
ffi::String scope = buffer.scope();
if (scope != "global") {
kwargs_keys.push_back("scope");
kwargs_values.push_back(LiteralDoc::Str(
scope, buffer_p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
}
}
// annotations (keyword, skip if empty)
if (!stmt->annotations.empty()) {
kwargs_keys.push_back("annotations");
kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->annotations, p->Attr("annotations")));
}
ExprDoc rhs = TIR(d, "alloc_buffer")->Call(args, kwargs_keys, kwargs_values);
ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d);
return AssignDoc(lhs, rhs, std::nullopt);
});
TVM_SCRIPT_REPR(tirx::AllocBufferNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::DeclBufferNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::SeqStmtNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::IfThenElseNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tirx::EvaluateNode, ReprPrintTIR);
} // namespace printer
} // namespace script
} // namespace tvm