blob: dba13cfbbcf19d70f259d9b4b871e6e239187a89 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file ir_utils.cc
* \brief Helper functions to construct and compose IR nodes.
*/
#include "ir_utils.h"
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <utility>
namespace tvm {
namespace tir {
/*!
 * \brief Clone a nest node and splice `body` in as its (currently no-op) body.
 *
 * Applies to every nest node whose child statement lives in a `body` field
 * (For, LetStmt, AttrStmt, AssertStmt, Allocate, AllocateConst, DeclBuffer).
 * Internal-linkage helper for MergeNest below.
 */
template <typename NodeType>
static Stmt WrapWithBody(const NodeType* op, Stmt body) {
  auto node = ffi::make_object<NodeType>(*op);
  ICHECK(is_no_op(node->body));
  node->body = std::move(body);
  return Stmt(node);
}

Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
  // Walk from innermost to outermost, wrapping `body` one level at a time.
  for (auto it = nest.rbegin(); it != nest.rend(); ++it) {
    const Stmt& s = *it;
    if (const auto* op = s.as<ForNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else if (const auto* op = s.as<LetStmtNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else if (const auto* op = s.as<AttrStmtNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else if (const auto* op = s.as<IfThenElseNode>()) {
      // IfThenElse keeps its child in `then_case`; an else branch is not allowed.
      auto node = ffi::make_object<IfThenElseNode>(*op);
      ICHECK(is_no_op(node->then_case));
      ICHECK(!node->else_case);
      node->then_case = std::move(body);
      body = Stmt(node);
    } else if (const auto* op = s.as<SeqStmtNode>()) {
      // For a sequence, the no-op placeholder must be the final element.
      auto node = ffi::make_object<SeqStmtNode>(*op);
      ICHECK(node->size() != 0 && is_no_op(node->seq[node->size() - 1]));
      node->seq.Set(node->size() - 1, body);
      body = Stmt(node);
    } else if (const auto* op = s.as<AssertStmtNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else if (const auto* op = s.as<AllocateNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else if (const auto* op = s.as<AllocateConstNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else if (const auto* op = s.as<DeclBufferNode>()) {
      body = WrapWithBody(op, std::move(body));
    } else {
      LOG(FATAL) << "not supported nest type";
    }
  }
  return body;
}
/*! \brief Merge groups of nests around body, from the innermost (last) group outward. */
Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body) {
  for (size_t i = nest.size(); i != 0; --i) {
    body = MergeNest(nest[i - 1], std::move(body));
  }
  return body;
}
/*!
 * \brief Mutator that rewrites TIR into SSA form, so that every variable has
 *        exactly one point of definition.
 *
 * Whenever a variable already recorded in `defined_` is defined a second time
 * (Let/LetStmt/For/Allocate/Block iter var/function parameter), a fresh
 * variable is created and pushed onto a per-variable scope stack; references
 * inside that scope resolve to the fresh variable.  Buffers whose backing var
 * (or shape/stride/elem_offset expressions) were remapped get replacement
 * Buffer objects tracked in `buf_remap_`.
 */
class IRConvertSSA final : public StmtExprMutator {
 public:
  /*!
   * \brief Convert one PrimFunc, de-duplicating vars already defined by
   *        previously visited functions.  Per-function remap state is reset
   *        before returning.
   */
  PrimFunc VisitPrimFunc(PrimFunc func) {
    // Redefinitions active while this function is processed; popped (in
    // reverse order of creation) before returning.
    std::vector<ScopedRedefine> redefines;
    // Remap parameters, if they were used in another function
    auto params = func->params.Map([&](const tir::Var& var) -> tir::Var {
      if (defined_.count(var.get())) {
        const ScopedRedefine& redefine = redefines.emplace_back(this, var);
        return redefine.new_var;
      } else {
        defined_.insert(var.get());
        return var;
      }
    });
    // Remap implicitly defined buffer parameters
    {
      std::unordered_set<const VarNode*> defined_params;
      for (const auto& var : func->params) {
        defined_params.insert(var.get());
      }
      for (const auto& [var, buffer] : func->buffer_map) {
        static_cast<void>(var);  // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
        // A var used in the buffer's shape/strides/elem_offset is implicitly
        // defined by this signature; redefine it if another function already
        // claimed it.
        auto check_expr = [&](const PrimExpr& expr) {
          auto* var_ptr = expr.as<VarNode>();
          if (!var_ptr) return;
          if (defined_params.count(var_ptr)) return;
          if (defined_.count(var_ptr)) {
            auto var = ffi::GetRef<Var>(var_ptr);
            redefines.emplace_back(this, var);
          } else {
            defined_.insert(var_ptr);
          }
        };
        for (const auto& dim : buffer->shape) {
          check_expr(dim);
        }
        for (const auto& stride : buffer->strides) {
          check_expr(stride);
        }
        check_expr(buffer->elem_offset);
      }
    }
    // Update the buffer map, based on the redefined parameters
    auto buffer_map = [&]() {
      ffi::Map<Var, Buffer> buffer_map;
      bool made_change = false;
      for (const auto& [var, buffer] : func->buffer_map) {
        auto new_var = GetRemappedVar(var);
        // The buffer's backing data var is also defined by this function.
        if (defined_.count(buffer->data.get())) {
          redefines.emplace_back(this, buffer->data);
        } else {
          defined_.insert(buffer->data.get());
        }
        auto new_buf = GetRemappedBuffer(buffer);
        made_change = made_change || !var.same_as(new_var) || !buffer.same_as(new_buf);
        buffer_map.Set(new_var, new_buf);
      }
      // Return the original map unchanged to preserve reference equality.
      if (made_change) {
        return buffer_map;
      } else {
        return func->buffer_map;
      }
    }();
    // Rewrite PrimExpr/Stmt values embedded in the function attributes.
    auto attrs = [&]() -> DictAttrs {
      if (!func->attrs.defined()) {
        return DictAttrs();
      }
      ffi::Map<ffi::String, ffi::Any> dict;
      bool made_change = false;
      for (const auto& [key, old_value] : func->attrs->dict) {
        auto value = old_value;
        if (auto* expr = value.as<PrimExprNode>()) {
          value = VisitExpr(ffi::GetRef<PrimExpr>(expr));
        } else if (auto* stmt = value.as<StmtNode>()) {
          value = VisitStmt(ffi::GetRef<Stmt>(stmt));
        }
        made_change = made_change || !value.same_as(old_value);
        dict.Set(key, value);
      }
      if (made_change) {
        return DictAttrs(dict);
      } else {
        return func->attrs;
      }
    }();
    auto body = VisitStmt(func->body);
    // If anything changed, update the returned function
    if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) ||
        !attrs.same_as(func->attrs) || !body.same_as(func->body)) {
      func = PrimFunc(params, body, func->ret_type, buffer_map, attrs);
    }
    // Pop the redefines in reverse order of creation
    while (redefines.size()) {
      redefines.pop_back();
    }
    // Thread-iter remappings only apply within a single function.
    function_scope_var_remap_.clear();
    return func;
  }
  // A variable reference resolves to its innermost active replacement.
  PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(ffi::GetRef<Var>(op)); }
  PrimExpr VisitExpr_(const LetNode* op) final {
    const Var& v = op->var;
    if (defined_.count(v.get())) {
      // Repeated definition: bind a fresh var for the duration of the body.
      PrimExpr value = this->VisitExpr(op->value);
      ScopedRedefine redefine(this, v);
      PrimExpr body = this->VisitExpr(op->body);
      return Let(redefine.new_var, value, body);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitExpr_(op);
    }
  }
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto output = VisitBufferAccess(std::move(node));
    return output;
  }
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto output = VisitBufferAccess(std::move(node));
    return output;
  }
  Stmt VisitStmt_(const DeclBufferNode* op) final {
    DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
    Buffer new_buffer = GetRemappedBuffer(decl->buffer);
    if (!new_buffer.same_as(decl->buffer)) {
      decl.CopyOnWrite()->buffer = std::move(new_buffer);
    }
    return decl;
  }
  Stmt VisitStmt_(const BlockNode* op) final {
    Block block = ffi::GetRef<Block>(op);
    // The BlockNode is the point of definition for the IterVar
    // instances. These re-defines must be present before visiting
    // the body of the BlockNode.
    std::vector<ScopedRedefine> redefines;
    ffi::Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
      if (defined_.count(iter_var->var.get())) {
        redefines.emplace_back(this, iter_var->var);
        iter_var.CopyOnWrite()->var = redefines.back().new_var;
      } else {
        defined_.insert(iter_var->var.get());
      }
      return iter_var;
    });
    ffi::Array<BufferRegion> reads =
        block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
    ffi::Array<BufferRegion> writes =
        block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });
    if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
        !iter_vars.same_as(op->iter_vars)) {
      auto write_ptr = block.CopyOnWrite();
      write_ptr->reads = reads;
      write_ptr->writes = writes;
      write_ptr->iter_vars = iter_vars;
    }
    Stmt output = Downcast<Block>(StmtExprMutator::VisitStmt_(block.get()));
    // Close the iter-var scopes opened above.
    while (redefines.size()) redefines.pop_back();
    return output;
  }
  // Rewrite the buffer referenced by a load/store/region if it was remapped.
  template <typename Node>
  Node VisitBufferAccess(Node node) {
    Buffer new_buf = GetRemappedBuffer(node->buffer);
    if (!new_buf.same_as(node->buffer)) {
      auto writer = node.CopyOnWrite();
      writer->buffer = new_buf;
    }
    return node;
  }
  // Look up the replacement for a var: the innermost scoped redefine wins,
  // then any function-scope remap, otherwise the var itself.
  Var GetRemappedVar(Var var) {
    if (auto it = scope_.find(var.get()); it != scope_.end() && it->second.size()) {
      return it->second.back();
    } else if (auto it = function_scope_var_remap_.find(var.get());
               it != function_scope_var_remap_.end()) {
      return it->second;
    } else {
      return var;
    }
  }
  Buffer GetRemappedBuffer(Buffer buf) {
    // Determine the buffer var that should be in the updated buffer,
    // given the current scope. If no redefines are present, then the
    // buffer var is unchanged.
    Var new_buffer_var = GetRemappedVar(buf->data);
    PrimExpr elem_offset = VisitExpr(buf->elem_offset);
    auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); };
    ffi::Array<PrimExpr> shape = buf->shape.Map(visit_expr);
    ffi::Array<PrimExpr> strides = buf->strides.Map(visit_expr);
    // If no mapping is required, return the original buffer.
    if (new_buffer_var.same_as(buf->data) && elem_offset.same_as(buf->elem_offset) &&
        shape.same_as(buf->shape) && strides.same_as(buf->strides)) {
      return buf;
    }
    // If the current scope already has a mapping of this buffer, use
    // the mapped buffer.
    auto key = buf.get();
    std::vector<Buffer>& buffers = buf_remap_[key];
    if (buffers.size() && buffers.back()->data.same_as(new_buffer_var)) {
      return buffers.back();
    }
    // Otherwise, make and return a new buffer object that uses the
    // new buffer, pushing it onto the scoped stack of existing
    // buffers. This will be popped when the new_buffer_var
    // redefinition is popped.
    Buffer new_buf = buf;
    {
      auto write_ptr = new_buf.CopyOnWrite();
      write_ptr->data = new_buffer_var;
      write_ptr->shape = shape;
      write_ptr->strides = strides;
      write_ptr->elem_offset = elem_offset;
    }
    buffers.push_back(new_buf);
    return new_buf;
  }
  Stmt VisitStmt_(const LetStmtNode* op) final {
    const Var& v = op->var;
    if (defined_.count(v.get())) {
      PrimExpr value = this->VisitExpr(op->value);
      ScopedRedefine redefine(this, v);
      Stmt body = this->VisitStmt(op->body);
      return LetStmt(redefine.new_var, value, body);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const ForNode* op) final {
    const Var& v = op->loop_var;
    if (defined_.count(v.get())) {
      // Loop var already defined elsewhere: the body is visited while the
      // redefine is active, then the loop is rebuilt around the fresh var.
      ScopedRedefine redefine(this, v);
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<ForNode>();
      return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
                 op->annotations);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const AllocateNode* op) final {
    const Var& v = op->buffer_var;
    if (defined_.count(v.get())) {
      ScopedRedefine redefine(this, v);
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AllocateNode>();
      return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body,
                      op->annotations);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (const IterVarNode* iter_var = op->node.as<IterVarNode>()) {
      Range dom = iter_var->dom;
      if (dom.defined()) {
        auto min = VisitExpr(dom->min);
        auto extent = VisitExpr(dom->extent);
        if (!min.same_as(iter_var->dom->min) || !extent.same_as(iter_var->dom->extent)) {
          dom = Range::FromMinExtent(min, extent);
        }
      }
      Var var = iter_var->var;
      bool delayed_define = false;
      if (auto it = function_scope_var_remap_.find(var.get());
          it != function_scope_var_remap_.end()) {
        // Reuse the remap chosen by an earlier AttrStmt on the same var so
        // all uses within the function agree.
        var = it->second;
      } else if (defined_.count(var.get())) {
        Var new_var = [&]() {
          if (var->type_annotation.defined()) {
            return Var(var->name_hint, var->type_annotation);
          } else {
            return Var(var->name_hint, var->dtype);
          }
        }();
        function_scope_var_remap_.insert({var.get(), new_var});
        var = new_var;
      } else {
        // The AttrStmt refers to an undefined variable. This is
        // allowed for some attributes, such as
        // "pragma_parallel_launch_point", which annotates a variable
        // that is about to occur in a ForNode. In these cases, the
        // ForNode and the AttrStmt must continue using the same
        // variable defintion.
        //
        // However, other AttrStmt, such as "thread_extent", act as
        // points of definition for the variable they annotate. If
        // the variable has not been defined after visiting the body,
        // we should mark it as defined before exiting. This ensures
        // correct de-duplication between multiple functions.
        //
        // This implementation may be simplified in the future by
        // moving "pragma_parallel_launch_point" to be an annotation
        // on the `ForNode`, rather than an `AttrStmt`.
        delayed_define = true;
      }
      IterVar new_iter_var;
      if (dom.same_as(iter_var->dom) && var.same_as(iter_var->var)) {
        new_iter_var = ffi::GetRef<IterVar>(iter_var);
      } else {
        new_iter_var = IterVar(dom, var, iter_var->iter_type, iter_var->thread_tag, iter_var->span);
      }
      auto value = VisitExpr(op->value);
      auto body = VisitStmt(op->body);
      Stmt output;
      if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) {
        output = ffi::GetRef<Stmt>(op);
      } else {
        output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
      }
      if (delayed_define) {
        if (!defined_.count(var.get())) {
          function_scope_var_remap_.insert({var.get(), var});
          defined_.insert(var.get());
        }
      }
      return output;
    } else if (const VarNode* v = op->node.as<VarNode>()) {
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AttrStmtNode>();
      // If the annotated var was redefined in the current scope, the
      // AttrStmt must reference the replacement as well.
      if (scope_.count(v) && scope_[v].size() != 0) {
        return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body);
      } else {
        return stmt;
      }
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

 private:
  /*!
   * \brief RAII helper: construction pushes a fresh replacement for
   *        `old_var` onto the parent's scope stack; destruction pops it
   *        (plus any buffer remap keyed to the replacement var).
   */
  struct ScopedRedefine {
    ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) {
      // Preserve the variable's concrete kind (SizeVar vs Var) and its type.
      bool is_size_var = old_var->IsInstance<SizeVarNode>();
      if (old_var->type_annotation.defined()) {
        if (is_size_var) {
          new_var = SizeVar(old_var->name_hint, old_var->type_annotation);
        } else {
          new_var = Var(old_var->name_hint, old_var->type_annotation);
        }
      } else {
        if (is_size_var) {
          new_var = SizeVar(old_var->name_hint, old_var->dtype);
        } else {
          new_var = Var(old_var->name_hint, old_var->dtype);
        }
      }
      parent->scope_[old_var.get()].push_back(new_var);
    }
    ~ScopedRedefine() {
      // parent is null if this instance has been moved from.
      if (parent) {
        parent->scope_[old_var.get()].pop_back();
        for (auto& kv : parent->buf_remap_) {
          std::vector<Buffer>& buffers = kv.second;
          if (buffers.size() && (buffers.back()->data.get() == new_var.get())) {
            buffers.pop_back();
          }
        }
      }
    }
    // Copying would pop the same scope entry twice; movable only.
    ScopedRedefine& operator=(const ScopedRedefine&) = delete;
    ScopedRedefine(const ScopedRedefine&) = delete;
    ScopedRedefine& operator=(ScopedRedefine&& other) {
      swap(other);
      return *this;
    }
    ScopedRedefine(ScopedRedefine&& other) { swap(other); }
    void swap(ScopedRedefine& other) {
      std::swap(parent, other.parent);
      std::swap(old_var, other.old_var);
      std::swap(new_var, other.new_var);
    }
    IRConvertSSA* parent{nullptr};
    Var old_var;
    Var new_var;
  };
  /*! \brief Stack of active replacements for each redefined var. */
  std::unordered_map<const VarNode*, std::vector<Var>> scope_;
  /*! \brief All vars that have a known point of definition so far. */
  std::unordered_set<const VarNode*> defined_;
  /*! \brief Scoped stack of replacement buffers for each remapped buffer. */
  std::unordered_map<const BufferNode*, std::vector<Buffer>> buf_remap_;
  /*! \brief Function-wide var remaps; cleared after each function. */
  std::unordered_map<const VarNode*, Var> function_scope_var_remap_;
};
/*! \brief Rewrite a statement into SSA form (each variable defined once). */
Stmt ConvertSSA(Stmt stmt) {
  IRConvertSSA converter;
  return converter(std::move(stmt));
}
/*!
 * \brief Fetch the storage scope annotated on a buffer variable.
 * \param buffer_var A variable whose type annotation must be a PointerType.
 * \return The storage scope string recorded in the pointer type.
 */
ffi::String GetPtrStorageScope(Var buffer_var) {
  const auto* pointer = buffer_var->type_annotation.as<PointerTypeNode>();
  ICHECK(pointer) << "The provided variable is not of pointer type";
  return pointer->storage_scope;
}
/*!
 * \brief Derive the physical allocation shape of a buffer.
 *
 * For a strided buffer, dimension i (for i > 0) is reconstructed as
 * strides[i-1] / strides[i]; each division must be provably exact.
 * Dimension 0 always keeps the logical shape[0].
 *
 * \param buffer The buffer whose allocation shape is requested.
 * \return The allocation shape expressions.
 *
 * Note: the analyzer is constructed once and reused across iterations; the
 * previous code created a fresh arith::Analyzer on every loop iteration.
 */
ffi::Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
  ffi::Array<PrimExpr> alloc_shape = buffer->shape;
  if (buffer->strides.size()) {
    ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
    arith::Analyzer analyzer;
    for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
      // Strides must be exactly nested so the ratio is an integer extent.
      ICHECK(analyzer.CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
      alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
    }
  }
  return alloc_shape;
}
/*!
 * \brief Translate indices on a matched (target) buffer into indices on its
 *        source buffer region.
 * \param match_buffer The match-buffer binding.
 * \param indices Indices into the target buffer, one per target dim.
 * \return Indices into the source buffer.
 */
ffi::Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
                                    const ffi::Array<PrimExpr>& indices) {
  const Buffer& target = match_buffer->buffer;
  const BufferRegion& source = match_buffer->source;
  ICHECK_EQ(indices.size(), target->shape.size());
  arith::Analyzer analyzer;
  ffi::Array<PrimExpr> result;
  result.reserve(source->region.size());
  // Leading source dims that the target does not cover must be singletons;
  // each contributes its min as a fixed index.
  size_t num_leading = source->region.size() - indices.size();
  size_t i = 0;
  for (; i < num_leading; ++i) {
    const Range& range = source->region[i];
    ICHECK(analyzer.CanProve(range->extent == 1));
    result.push_back(range->min);
  }
  // Remaining dims: offset each target index by the region's min.
  for (; i < source->region.size(); ++i) {
    result.push_back(source->region[i]->min + indices[i - num_leading]);
  }
  return result;
}
/*!
 * \brief Translate a region on a matched (target) buffer into a region on
 *        its source buffer.
 * \param match_buffer The match-buffer binding.
 * \param region Region over the target buffer, one range per target dim.
 * \return The corresponding region over the source buffer.
 */
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) {
  const Buffer& target = match_buffer->buffer;
  const BufferRegion& source = match_buffer->source;
  ICHECK_EQ(region.size(), target->shape.size());
  arith::Analyzer analyzer;
  Region result;
  result.reserve(source->region.size());
  // Leading source dims not covered by the target region must be singletons.
  size_t num_leading = source->region.size() - region.size();
  size_t i = 0;
  for (; i < num_leading; ++i) {
    const Range& source_range = source->region[i];
    ICHECK(analyzer.CanProve(source_range->extent == 1));
    result.push_back(Range::FromMinExtent(source_range->min, 1));
  }
  // Remaining dims: shift the target range by the source region's min.
  for (; i < source->region.size(); ++i) {
    const Range& target_range = region[i - num_leading];
    result.push_back(
        Range::FromMinExtent(source->region[i]->min + target_range->min, target_range->extent));
  }
  return result;
}
/*!
 * \brief Try to convert `condition_` into a solved set of integer range
 *        constraints over the variables it mentions.
 * \return The solved constraints, or std::nullopt when the condition is a
 *         constant, yields no solvable equations, or could not be fully
 *         reduced to ranges.
 */
ffi::Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
  // extract equations and related vars from condition expression.
  // currently only extract simple integral equations which could be solvable.
  arith::Analyzer analyzer;
  PrimExpr condition = analyzer.Simplify(condition_);
  if (is_const_int(condition)) {
    return std::nullopt;
  }
  ffi::Array<PrimExpr> equations;
  ffi::Array<Var> vars;
  // Recursively collect comparison sub-expressions; conjunctions (And) and
  // likely() wrappers are looked through.
  std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) {
    if (e->IsInstance<GENode>() || e->IsInstance<GTNode>() || e->IsInstance<LENode>() ||
        e->IsInstance<LTNode>() || e->IsInstance<EQNode>() || e->IsInstance<NENode>()) {
      // A comparison is "simple" when its operands are built only from
      // +, -, *, floordiv, floormod, integer constants and integer vars.
      bool is_simple = true;
      std::vector<Var> cand_vars;
      PostOrderVisit(e, [&cand_vars, &is_simple, &e](const ObjectRef& obj) {
        if (obj.same_as(e)) {
          return;
        } else if (const VarNode* var = obj.as<VarNode>()) {
          if (var->dtype.is_int() || var->dtype.is_uint()) {
            cand_vars.push_back(ffi::GetRef<Var>(var));
          }
        } else {
          is_simple &= obj->IsInstance<AddNode>() || obj->IsInstance<SubNode>() ||
                       obj->IsInstance<MulNode>() || obj->IsInstance<FloorDivNode>() ||
                       obj->IsInstance<FloorModNode>() || obj->IsInstance<IntImmNode>();
        }
      });
      if (is_simple && !cand_vars.empty()) {
        // De-duplicate vars across equations before recording them.
        for (const Var& new_var : cand_vars) {
          if (!std::any_of(vars.begin(), vars.end(),
                           [&new_var](const Var& v) { return v.same_as(new_var); })) {
            vars.push_back(new_var);
          }
        }
        equations.push_back(Downcast<PrimExpr>(e));
      }
    } else if (e->IsInstance<AndNode>()) {
      And op = Downcast<And>(e);
      fvisit(op->a);
      fvisit(op->b);
    } else if (e->IsInstance<CallNode>()) {
      Call op = Downcast<Call>(e);
      if (op->op.same_as(builtin::likely())) {
        fvisit(op->args[0]);
      }
    }
  };
  fvisit(condition);
  if (equations.empty() || vars.empty()) {
    return std::nullopt;
  }
  // build dom ranges for related vars
  ffi::Map<Var, Range> ranges;
  for (const Var& v : vars) {
    arith::IntSet dom;
    // A relaxed bound takes precedence over a hint bound.
    auto relax_it = relax_map_->find(v.get());
    if (relax_it != relax_map_->end()) {
      dom = relax_it->second;
    } else {
      auto hint_it = hint_map_->find(v.get());
      if (hint_it != hint_map_->end()) {
        dom = hint_it->second;
      }
    }
    if (dom.defined()) {
      ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
    }
  }
  // solve constraints
  arith::IntConstraints constraint(vars, ranges, equations);
  arith::IntConstraints result = arith::SolveInequalitiesToRange(constraint);
  // Leftover relations mean the solver could not turn everything into
  // ranges; treat that as a failure.
  if (!result->relations.empty()) {
    return std::nullopt;
  }
  return result;
}
/*!
 * \brief Create a scoped context that narrows variable bounds while
 *        `condition` is known to hold.
 * \param condition The boolean condition to extract bounds from.
 * \param relax_map Bounds of vars being relaxed; entries may be tightened.
 * \param hint_map Bounds of free vars; entries may be added or tightened.
 * \param pending_conditions Conditions that could not be solved; entries
 *        appended by this context are removed again on exit.
 */
ConditionalBoundsContext::ConditionalBoundsContext(
    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
    std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
    std::vector<PrimExpr>* pending_conditions)
    : condition_(condition),
      relax_map_(relax_map),
      hint_map_(hint_map),
      pending_conditions_(pending_conditions),
      // Remember the size so ExitWithScope can truncate back to it.
      origin_pending_conditions_num_(pending_conditions->size()) {}
/*!
 * \brief Apply the condition's solved bounds to relax_map_/hint_map_,
 *        recording prior values in origin_map_ so ExitWithScope can undo.
 */
void ConditionalBoundsContext::EnterWithScope() {
  ffi::Optional<arith::IntConstraints> constraints = TrySolveCondition();
  if (!constraints.defined()) {
    // fail to process the condition, add to unresolved
    pending_conditions_->push_back(condition_);
    return;
  }
  // update solved var ranges
  for (const auto& kv : constraints.value()->ranges) {
    const VarNode* var = kv.first.get();
    arith::IntSet new_dom = arith::IntSet::FromRange(kv.second);
    auto relax_it = relax_map_->find(var);
    if (relax_it != relax_map_->end()) {
      // this is a bound for relaxed var
      // Save the previous bound so ExitWithScope can restore it.
      origin_map_.emplace(var, relax_it->second);
      relax_it->second = arith::Intersect({relax_it->second, new_dom});
    } else {
      // this is a bound for free var
      auto hint_it = hint_map_->find(var);
      if (hint_it != hint_map_->end()) {
        origin_map_.emplace(var, hint_it->second);
        hint_it->second = arith::Intersect({hint_it->second, new_dom});
      } else {
        // No prior bound: record "Nothing" so the entry is erased on exit.
        origin_map_.emplace(var, arith::IntSet::Nothing());
        hint_map_->insert(hint_it, {var, new_dom});
      }
    }
  }
}
/*! \brief Restore all bounds changed by EnterWithScope. */
void ConditionalBoundsContext::ExitWithScope() {
  // Drop any unresolved condition appended by EnterWithScope.
  pending_conditions_->resize(origin_pending_conditions_num_);
  for (const auto& p : origin_map_) {
    const auto* var = p.first;
    auto relax_it = relax_map_->find(var);
    if (relax_it != relax_map_->end()) {
      // recover bound for relaxed var
      relax_it->second = p.second;
    } else {
      // recover bound for free var
      auto hint_it = hint_map_->find(var);
      ICHECK(hint_it != hint_map_->end());
      if (p.second.IsNothing()) {
        // The entry was first introduced by this context; remove it.
        hint_map_->erase(hint_it);
      } else {
        hint_it->second = p.second;
      }
    }
  }
}
std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
return std::make_pair(op->value, inner->value);
}
/*! \brief Collect storage alignment information from annotations. */
class StorageAlignCollector : public StmtVisitor {
 private:
  friend std::unordered_map<Var, StorageAlignAnnotation> CollectStorageAlignAnnotation(
      const Stmt& body);
  /*! \brief For s-stir, the alignment annotations reside in block annotations. */
  void VisitStmt_(const BlockNode* op) final {
    auto it = op->annotations.find(attr::buffer_dim_align);
    if (it != op->annotations.end()) {
      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
      for (const auto& storage_align_tuple : storage_align_annotation) {
        // The first tuple element indexes into the block's write regions.
        int buffer_index = storage_align_tuple.get<0>();
        const Buffer& buffer = op->writes[buffer_index]->buffer;
        storage_align_[buffer->data].push_back(storage_align_tuple);
      }
    }
    // Continue into nested blocks.
    StmtVisitor::VisitStmt_(op);
  }
  /*! \brief For lowered tir, the alignment annotations reside in allocate annotations. */
  void VisitStmt_(const AllocateNode* op) final {
    auto it = op->annotations.find(attr::buffer_dim_align);
    if (it != op->annotations.end()) {
      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
      for (const auto& storage_align_tuple : storage_align_annotation) {
        int buffer_index = storage_align_tuple.get<0>();
        // the first buffer idx info is meaningless for allocate
        // stmt and should set as negative intentionally.
        ICHECK_EQ(buffer_index, -1);
        storage_align_[op->buffer_var].push_back(storage_align_tuple);
      }
    }
    StmtVisitor::VisitStmt_(op);
  }
  /*! \brief The map from buffer var to its storage alignment information. */
  std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
};
/*!
 * \brief Collect all buffer_dim_align annotations found in `body`.
 * \param body The statement to scan.
 * \return Map from buffer var to its collected alignment tuples.
 */
std::unordered_map<Var, StorageAlignAnnotation> CollectStorageAlignAnnotation(const Stmt& body) {
  StorageAlignCollector collector;
  collector(body);
  // Move the accumulated map out of the collector (named member, so NRVO
  // does not apply; the explicit move avoids a copy).
  return std::move(collector.storage_align_);
}
/*!
 * \brief Convert a decimal string to int, aborting with a fatal error on
 *        malformed or out-of-range input.
 * \param str The string to parse.
 * \return The parsed integer value.
 *
 * std::stoi throws std::invalid_argument for non-numeric input and
 * std::out_of_range for values that do not fit in an int.  The previous
 * implementation only caught std::invalid_argument, so an out-of-range
 * literal escaped as an unhandled exception instead of producing the fatal
 * diagnostic.  Catching std::exception covers both failure modes.
 */
int Stoi(const std::string& str) {
  try {
    return std::stoi(str);
  } catch (const std::exception&) {
    LOG(FATAL) << "Cannot convert \"" << str << "\" to int";
    throw;
  }
}
/*!
 * \brief Derive the two fragment dimensions for a WMMA buffer.
 * \param shape_str Shape string of the form "m, n, k".
 * \param scope One of "wmma.matrix_a", "wmma.matrix_b", "wmma.accumulator".
 * \return (m, k), (k, n) or (m, n) respectively; (0, 0) for any other scope.
 */
std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
                                                   const std::string& scope) {
  // Parse the three comma-separated components "m, n, k".
  size_t start = 0;
  size_t comma = shape_str.find(", ", start);
  size_t m = Stoi(shape_str.substr(start, comma - start));
  start = comma + 2;
  comma = shape_str.find(", ", start);
  size_t n = Stoi(shape_str.substr(start, comma - start));
  start = comma + 2;
  size_t k = Stoi(shape_str.substr(start));
  if (scope == "wmma.matrix_a") {
    return {static_cast<int32_t>(m), static_cast<int32_t>(k)};
  }
  if (scope == "wmma.matrix_b") {
    return {static_cast<int32_t>(k), static_cast<int32_t>(n)};
  }
  if (scope == "wmma.accumulator") {
    return {static_cast<int32_t>(m), static_cast<int32_t>(n)};
  }
  return {0, 0};
}
/*!
 * \brief Decide whether a PrimFunc runs on the host.
 * \param func The function to inspect.
 * \return true if it carries the kIsHostFunc attribute; otherwise, whether
 *         its target (if any) has the "cpu" key; std::nullopt when neither
 *         attribute is present.
 */
std::optional<bool> IsHostFunc(const PrimFunc& func) {
  if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
    return true;
  }
  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  if (!target) {
    return std::nullopt;
  }
  return target.value()->HasKey("cpu");
}
namespace transform {
/*!
 * \brief Module pass that rewrites every PrimFunc in the module into SSA
 *        form, sharing one IRConvertSSA instance so variables are
 *        de-duplicated across functions.
 */
Pass ConvertSSA() {
  auto pass_func = [](IRModule mod, PassContext ctx) {
    tir::IRConvertSSA converter;
    ffi::Map<GlobalVar, BaseFunc> functions;
    bool made_change = false;
    for (auto [gvar, base_func] : mod->functions) {
      // Only PrimFuncs are rewritten; other function kinds pass through.
      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
        auto updated = converter.VisitPrimFunc(ffi::GetRef<tir::PrimFunc>(prim_func));
        made_change = made_change || !updated.same_as(base_func);
        base_func = updated;
      }
      functions.Set(gvar, base_func);
    }
    // Preserve reference equality of the module when nothing changed.
    if (made_change) {
      mod.CopyOnWrite()->functions = std::move(functions);
    }
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {});
}
// Register the pass constructor with the global FFI registry under the key
// "tir.transform.ConvertSSA" so frontends can look it up by name.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tir.transform.ConvertSSA", ConvertSSA);
}
} // namespace transform
} // namespace tir
} // namespace tvm