blob: b2e091f38019fa16d45310c25d19ed9b50100b8f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
// Printer for `tir::For`: renders the loop (and, when possible, a chain of
// perfectly nested loops) as a Python-style `for ... in ...:` statement (ForDoc).
//
// Phase 1 (only when `d->cfg->syntax_sugar` is on): walk down the chain of
// directly nested For nodes and collect every loop that is serial, starts at
// zero, has no annotations, has a trivial step, and whose extent does not use
// any loop variable already collected. If two or more loops qualify, they are
// printed as one `for a, b, ... in <TIR>.grid(e0, e1, ...):` header.
//
// Phase 2 otherwise prints a single loop. The header callee is `range` for a
// serial loop without annotations, and `<TIR>.serial/parallel/unroll/
// vectorized/thread_binding` for the other cases. `<TIR>` is presumably the
// configured TIR prefix, e.g. `T`; confirm in TIR() in utils.h.
//
// Positional args are (extent) when min is zero and the step is trivial, and
// (min, min + extent) otherwise. A non-trivial step becomes a third positional
// arg for `range` and a `step=` kwarg otherwise. `thread=` and `annotations=`
// kwargs are emitted when present.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<tir::For>("", [](tir::For loop, AccessPath loop_p, IRDocsifier d) -> Doc {
      // ---- Phase 1: detect a foldable loop nest for `T.grid` sugar ----
      std::vector<const tir::ForNode*> nest;
      std::unordered_set<const tir::VarNode*> nest_vars;
      // True when `expr` references a loop variable already folded into the nest;
      // such an extent cannot be hoisted into a single grid header.
      auto depends_on_nest = [&nest_vars](const PrimExpr& expr) -> bool {
        return tir::UsesVar(expr, [&nest_vars](const tir::VarNode* v) -> bool {
          return nest_vars.count(v) != 0;
        });
      };
      if (d->cfg->syntax_sugar) {
        const tir::ForNode* cur = loop.get();
        while (cur != nullptr) {
          ICHECK(cur->loop_var->dtype == cur->min->dtype);
          ICHECK(cur->loop_var->dtype == cur->extent->dtype);
          const bool foldable = cur->kind == tir::ForKind::kSerial &&  //
                                tir::is_zero(cur->min) &&              //
                                cur->annotations.empty() &&            //
                                cur->HasTrivialStep() &&               //
                                !depends_on_nest(cur->extent);
          if (!foldable) {
            break;
          }
          nest.push_back(cur);
          nest_vars.insert(cur->loop_var.get());
          cur = cur->body.as<tir::ForNode>();
        }
      }

      With<TIRFrame> frame(d, loop);

      // A grid is only worth emitting for two or more folded loops.
      if (nest.size() > 1) {
        const int depth = static_cast<int>(nest.size());
        ffi::Array<ExprDoc> vars;
        ffi::Array<ExprDoc> extents;
        vars.reserve(depth);
        extents.reserve(depth);
        // Each nested loop lives at `<parent path>.body`, so step the path down in lockstep.
        AccessPath path = loop_p;
        for (const tir::ForNode* node : nest) {
          vars.push_back(DefineVar(node->loop_var, *frame, d));
          extents.push_back(d->AsDoc<ExprDoc>(node->extent, path->Attr("extent")));
          path = path->Attr("body");
        }
        AsDocBody(nest.back()->body, path, (*frame).get(), d);
        return ForDoc(TupleDoc(vars), TIR(d, "grid")->Call(extents), (*frame)->stmts);
      }

      // ---- Phase 2: single loop with an explicit kind ----
      ExprDoc lhs = DefineVar(loop->loop_var, *frame, d);

      ffi::Array<ExprDoc> args;
      ffi::Array<ffi::String> kwargs_keys;
      ffi::Array<ExprDoc> kwargs_values;

      // The start may only be dropped when it is zero AND there is no step,
      // because `range(stop, step)` is not a valid form.
      const bool omit_start = tir::is_zero(loop->min) && loop->HasTrivialStep();
      if (omit_start) {
        args.push_back(d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent")));
      } else {
        args.push_back(d->AsDoc<ExprDoc>(loop->min, loop_p->Attr("min")));
        args.push_back(d->AsDoc<ExprDoc>(loop->min + loop->extent, loop_p->Attr("extent")));
      }

      ffi::Optional<ExprDoc> annotations_doc = std::nullopt;
      if (!loop->annotations.empty()) {
        annotations_doc = d->AsDoc<ExprDoc>(loop->annotations, loop_p->Attr("annotations"));
      }

      bool use_range_sugar = false;
      ExprDoc prefix{ffi::UnsafeInit()};
      ffi::Optional<ExprDoc> thread_doc = std::nullopt;
      switch (loop->kind) {
        case tir::ForKind::kSerial:
          // Plain Python `range` only when nothing extra needs to be carried.
          if (loop->annotations.empty()) {
            use_range_sugar = true;
            prefix = IdDoc("range");
          } else {
            prefix = TIR(d, "serial");
          }
          break;
        case tir::ForKind::kParallel:
          prefix = TIR(d, "parallel");
          break;
        case tir::ForKind::kUnrolled:
          prefix = TIR(d, "unroll");
          break;
        case tir::ForKind::kVectorized:
          prefix = TIR(d, "vectorized");
          break;
        case tir::ForKind::kThreadBinding:
          prefix = TIR(d, "thread_binding");
          // Expects `thread_binding` to be set for this kind (`.value()` requires it).
          thread_doc = LiteralDoc::Str(loop->thread_binding.value()->thread_tag,
                                       loop_p->Attr("thread_binding"));
          break;
        default:
          LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind);
      }

      // Keyword arguments are emitted in a fixed order: thread, annotations, step.
      if (thread_doc.defined()) {
        kwargs_keys.push_back("thread");
        kwargs_values.push_back(thread_doc.value());
      }
      if (annotations_doc.defined()) {
        kwargs_keys.push_back("annotations");
        kwargs_values.push_back(annotations_doc.value());
      }
      if (!loop->HasTrivialStep()) {
        ExprDoc step_doc = d->AsDoc<ExprDoc>(*loop->step, loop_p->Attr("step"));
        if (use_range_sugar) {
          args.push_back(step_doc);  // range(start, stop, step)
        } else {
          kwargs_keys.push_back("step");
          kwargs_values.push_back(step_doc);
        }
      }

      ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
      AsDocBody(loop->body, loop_p->Attr("body"), (*frame).get(), d);
      return ForDoc(lhs, rhs, (*frame)->stmts);
    });
// Hook tir::ForNode's textual repr up to ReprPrintTIR. Presumably this makes the
// node's repr go through the TVMScript printer registered above; confirm against
// the TVM_SCRIPT_REPR / ReprPrintTIR definitions.
TVM_SCRIPT_REPR(tir::ForNode, ReprPrintTIR);
} // namespace printer
} // namespace script
} // namespace tvm