/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file compact_buffer_region.cc
* \brief Compact each buffer allocation to the exact region it needs.
*/
#include <tvm/arith/int_set.h>
#include <tvm/arith/int_solver.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <numeric>
#include <stack>
#include "../../support/arena.h"
#include "../../support/nd_int_set.h"
#include "../../support/utils.h"
#include "../../tir/transform/ir_utils.h"
#include "../schedule/utils.h"
namespace tvm {
namespace s_tir {
using namespace tvm::tir;
using support::NDIntSet;
/*! \brief A more constrained bound estimate for an n-dimensional int set. */
NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
const std::unordered_map<const VarNode*, arith::IntSet>& dom_map,
arith::Analyzer* analyzer) {
std::unordered_map<Var, Range, ObjectPtrHash, ObjectPtrEqual> var_dom;
for (const auto& it : dom_map) {
var_dom[ffi::GetRef<Var>(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0));
}
ffi::Optional<ffi::Array<arith::IntSet>> eval_res =
arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
if (eval_res.defined()) {
return NDIntSet(eval_res.value().begin(), eval_res.value().end());
}
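// Fall back to a per-dimension interval evaluation when the tighter symbolic
// upper-bound estimate is unavailable.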
return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
}
/*!
* \brief Collect buffer aliasing information.
*/
class Var2BufferCollector : public StmtExprVisitor {
public:
/*! \brief Map the buffer var to all aliased buffers. */
std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>> var2buffer_;
private:
void VisitStmt_(const BufferStoreNode* op) final {
var2buffer_[op->buffer->data].insert(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode* op) final {
var2buffer_[op->buffer->data].insert(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const SBlockNode* op) final {
for (const Buffer& buffer : op->alloc_buffers) {
var2buffer_[buffer->data].insert(buffer);
}
for (const MatchBufferRegion& region : op->match_buffers) {
var2buffer_[region->buffer->data].insert(region->buffer);
var2buffer_[region->source->buffer->data].insert(region->source->buffer);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const DeclBufferNode* op) final {
var2buffer_[op->buffer->data].insert(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
};
/*!
* \brief Collect the access region of each buffer.
* \note Regions of the function's parameter buffers are not collected.
*/
class BufferAccessRegionCollector : public StmtExprVisitor {
public:
static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
const PrimFunc& f, bool collect_inbound) {
BufferAccessRegionCollector region_collector(collect_inbound);
// collect buffer var to aliased buffer mapping
Var2BufferCollector var2buffer_collector;
var2buffer_collector(f->body);
std::swap(region_collector.var2buffer_, var2buffer_collector.var2buffer_);
// collect buffer access regions
region_collector(f->body);
return std::move(region_collector.buffer_access_region_);
}
private:
struct BufferAccessInfo {
/*! \brief The buffer. */
Buffer buffer;
/*! \brief The buffer access region, which can be updated during visiting. */
NDIntSet accessed_region;
explicit BufferAccessInfo(const Buffer& buffer, const NDIntSet& region)
: buffer(buffer), accessed_region(region) {}
};
explicit BufferAccessRegionCollector(bool collect_inbound) : collect_inbound_(collect_inbound) {}
/**************** Visitor overload ****************/
void VisitStmt_(const BufferStoreNode* op) final {
VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
VisitExpr(op->value);
}
void VisitExpr_(const BufferLoadNode* op) final {
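// Prefer an explicitly annotated access region (from the enclosing block) over the
// single-point access implied by the load indices.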
auto explicit_it = explicit_access_annotations_.find(op->buffer);
if (explicit_it != explicit_access_annotations_.end()) {
VisitBufferAccess(explicit_it->second);
} else {
VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const VarNode* op) final { VisitBufferVar(ffi::GetRef<Var>(op)); }
void VisitStmt_(const ForNode* op) final {
Range loop_range = Range::FromMinExtent(op->min, op->extent);
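// Thread-binding loops are recorded as thread iter vars so that NeedRelaxThread() can
// later decide whether their domain may be relaxed for the buffer's storage scope.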
IterVar iter = op->kind == ForKind::kThreadBinding
? IterVar(Range(), op->loop_var, IterVarType::kThreadIndex,
op->thread_binding.value()->thread_tag)
: IterVar(Range(), op->loop_var, IterVarType::kDataPar);
ancestor_iters_.push_back(iter);
dom_analyzer_.Bind(op->loop_var, loop_range);
dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->loop_var.get());
ancestor_iters_.pop_back();
}
void VisitStmt_(const LetStmtNode* op) final {
StmtExprVisitor::VisitExpr(op->value);
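// Only index-typed let bindings participate in domain analysis; other dtypes stay opaque.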
if (arith::IsIndexType(op->value->dtype)) {
dom_analyzer_.Bind(op->var, op->value);
dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value));
}
StmtExprVisitor::VisitStmt(op->body);
if (arith::IsIndexType(op->value->dtype)) {
dom_map_.erase(op->var.get());
}
}
void VisitExpr_(const LetNode* op) final {
StmtExprVisitor::VisitExpr(op->value);
if (arith::IsIndexType(op->value->dtype)) {
dom_analyzer_.Bind(op->var, op->value);
dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value));
}
StmtExprVisitor::VisitExpr(op->body);
if (arith::IsIndexType(op->value->dtype)) {
dom_map_.erase(op->var.get());
}
}
void VisitStmt_(const IfThenElseNode* op) final {
// Visit condition
StmtExprVisitor::VisitExpr(op->condition);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case) {
// Visit else branch
With<ConditionalBoundsContext> ctx(!op->condition, &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::if_then_else())) {
// Visit condition
StmtExprVisitor::VisitExpr(op->args[0]);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(!op->args[0], &dom_map_, &hint_map_,
&pending_conditions_);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const SBlockNode* op) final {
// Step 0. Check there is no init part and block is opaque
ICHECK(!op->init.defined());
ICHECK_EQ(op->iter_vars.size(), 0) << "CompactBufferRegion only works on opaque blocks";
// Step 1. Record and update current read/write region annotations
std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
cur_access_annotations;
for (const BufferRegion& region : op->reads) {
cur_access_annotations[region->buffer].push_back(region);
}
for (const BufferRegion& region : op->writes) {
cur_access_annotations[region->buffer].push_back(region);
}
for (auto& p : cur_access_annotations) {
auto& regions = access_annotations_[p.first];
p.second.swap(regions);
}
// Step 2. Record explicit read/write region annotations
auto record_explicit_region = [&](const ffi::String& attr_key, BufferIndexType index_type) {
auto it = op->annotations.find(attr_key);
if (it != op->annotations.end()) {
ffi::Array<Integer> buffer_indices = Downcast<ffi::Array<Integer>>((*it).second);
for (const auto& index : buffer_indices) {
int buffer_index = index->value;
// Bound-check against the region list that is actually indexed.
size_t num_regions =
    index_type == BufferIndexType::kRead ? op->reads.size() : op->writes.size();
if (buffer_index >= 0 && buffer_index < static_cast<int>(num_regions)) {
const BufferRegion& explicit_region = index_type == BufferIndexType::kRead
? op->reads[buffer_index]
: op->writes[buffer_index];
explicit_access_annotations_[explicit_region->buffer] = explicit_region;
}
}
}
};
record_explicit_region(tir::attr::explicit_read_region, BufferIndexType::kRead);
record_explicit_region(tir::attr::explicit_write_region, BufferIndexType::kWrite);
// Step 3. Record the scope depth (current size of ancestor_iters_) at each allocation's definition point
for (const Buffer& buffer : op->alloc_buffers) {
VisitBufferDef(buffer->data);
}
// Step 4. Visit match buffers
for (const MatchBufferRegion& region : op->match_buffers) {
VisitBufferAccess(region->source);
}
// Step 5. Visit block body recursively
StmtExprVisitor::VisitStmt_(op);
// Step 6. Recover read/write region annotations
for (auto& p : cur_access_annotations) {
auto& regions = access_annotations_[p.first];
if (p.second.empty()) {
access_annotations_.erase(p.first);
} else {
regions.swap(p.second);
}
}
// Step 7. Clear explicit access annotations
explicit_access_annotations_.clear();
// Step 8. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
for (const Buffer& buffer : op->alloc_buffers) {
ICHECK_EQ(var2buffer_[buffer->data].size(), 1)
<< "Block allocation buffer shoud not be alised";
SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
}
}
void VisitStmt_(const SBlockRealizeNode* op) final {
With<ConditionalBoundsContext> ctx(op->predicate, &dom_map_, &hint_map_, &pending_conditions_);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateNode* op) final {
auto it = var2buffer_.find(op->buffer_var);
// Skip compaction when the buffer definition and the allocation
// are not one-to-one with the same dtype.
if (it == var2buffer_.end() || it->second.size() > 1) {
return StmtExprVisitor::VisitStmt_(op);
}
const Buffer& buffer = *it->second.begin();
if (buffer->dtype != op->dtype) {
return StmtExprVisitor::VisitStmt_(op);
}
// Step 0. Record the scope depth (current size of ancestor_iters_) at this allocation's definition point
VisitBufferDef(op->buffer_var);
// Step 1. Visit block body recursively
StmtExprVisitor::VisitStmt(op->body);
// Step 2. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
}
void VisitStmt_(const AttrStmtNode* op) final {
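// thread_extent / virtual_thread attributes introduce an implicit iteration domain;
// bind it like a loop while visiting the sub-tree.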
if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
IterVar iter = Downcast<IterVar>(op->node);
ancestor_iters_.push_back(iter);
Range dom = iter->dom;
if (!dom.defined()) {  // dom is undefined for legacy TE schedules
dom = Range::FromMinExtent(make_zero(op->value->dtype), op->value);
}
dom_analyzer_.Bind(iter->var, dom);
dom_map_.emplace(iter->var.get(), arith::IntSet::FromRange(dom));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(iter->var.get());
ancestor_iters_.pop_back();
return;
}
StmtExprVisitor::VisitStmt_(op);
}
/**************** Helper functions ****************/
/*! \brief Record information on the buffer defining point. */
void VisitBufferDef(const Var& buffer_data) {
auto it = buffer_scope_depth_.find(buffer_data);
ICHECK(it == buffer_scope_depth_.end()) << buffer_data << " has duplicate definitions";
buffer_scope_depth_.insert(it, {buffer_data, ancestor_iters_.size()});
}
void VisitBufferAccess(const BufferRegion& buffer_region) {
const Buffer& buffer = buffer_region->buffer;
auto it = buffer_scope_depth_.find(buffer->data);
if (it != buffer_scope_depth_.end()) {
size_t n_ancestor_loops = it->second;
// Step 1. Temporarily drop the domains of ancestor iters defined outside the buffer's
// allocation scope so they are not relaxed, unless NeedRelaxThread() says they must be.
std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
for (size_t i = 0; i < n_ancestor_loops; ++i) {
const IterVar& iter = ancestor_iters_[i];
const VarNode* v = iter->var.get();
if (NeedRelaxThread(iter, runtime::StorageScope::Create(buffer.scope()))) {
continue;
}
auto dom_it = dom_map_.find(v);
ICHECK(dom_it != dom_map_.end())
<< "Could not find domain for loop variable " << v->name_hint;
non_relaxed[i] = dom_it->second;
dom_map_.erase(dom_it);
}
// Step 2. Relax the access region
auto normalize_pred = [](const PrimExpr& pred) {
if (pred->dtype.is_bool()) return pred;
return pred != make_zero(pred->dtype);
};
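// Fold all pending scope conditions into one predicate; non-boolean conditions are
// normalized to `x != 0` first.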
PrimExpr predicate = dom_analyzer_.Simplify(
std::accumulate(pending_conditions_.begin(), pending_conditions_.end(), const_true(),
[normalize_pred](const PrimExpr& x, const PrimExpr& y) {
return normalize_pred(x) && normalize_pred(y);
}));
NDIntSet nd_int_set =
NDIntSetEval(buffer_region->region, predicate, dom_map_, &dom_analyzer_);
// Step 3. Restore the non-relaxed ancestor loops domain
for (size_t i = 0; i < n_ancestor_loops; ++i) {
const VarNode* v = ancestor_iters_[i]->var.get();
dom_map_.emplace(v, non_relaxed[i]);
}
// Step 4. Update relaxed_accesses_ dict
auto access_it = relaxed_accesses_.find(buffer);
if (access_it != relaxed_accesses_.end()) {
support::NDIntSetUnionWith(&access_it->second, nd_int_set);
} else {
relaxed_accesses_.insert(access_it, {buffer, nd_int_set});
}
}
}
void VisitBufferVar(const Var& var) {
auto it = var2buffer_.find(var);
if (it == var2buffer_.end()) {
return;
}
for (const Buffer& buffer : it->second) {
auto annotation_it = access_annotations_.find(buffer);
if (annotation_it != access_annotations_.end()) {
// An opaque access falls back to the enclosing block's read/write region annotations
for (const BufferRegion& region : annotation_it->second) {
VisitBufferAccess(region);
}
} else {
VisitBufferAccess(BufferRegion::FullRegion(buffer));
}
}
}
/*! \brief Check whether the thread binding iter should be relaxed with given storage scope. */
static bool NeedRelaxThread(const IterVar& iter, const runtime::StorageScope& scope) {
if (iter->iter_type != IterVarType::kThreadIndex) {
return false;
}
// When there is warp memory, threadIdx.x must be set to the warp index.
return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create((iter->thread_tag)));
}
/*!
* \brief Simplify and narrow the access region collected in `relaxed_accesses_`, and
* record the result in `buffer_access_region_`. If `collect_inbound_` is true, the
* resulting region never exceeds the original buffer shape.
*/
void SimplifyAndNarrowBufferRegionFromNDIntSet(const Buffer& buffer) {
auto it = relaxed_accesses_.find(buffer);
ICHECK(it != relaxed_accesses_.end())
<< buffer << " is allocated but not accessed within block scope";
const ffi::Array<PrimExpr>& original_shape = buffer->shape;
const NDIntSet& nd_int_set = it->second;
ffi::Array<Range>& result_region = buffer_access_region_[buffer];
result_region.resize(nd_int_set.size());
for (size_t i = 0; i < nd_int_set.size(); ++i) {
const arith::IntSet& int_set = nd_int_set[i];
Range original =
Range(/*begin=*/make_zero(original_shape[i]->dtype), /*end=*/original_shape[i]);
Range range = int_set.CoverRange(original);
PrimExpr min, extent;
if (collect_inbound_) {
min = dom_analyzer_.Simplify(tvm::max(0, range->min));
extent = range->extent;
// Try a stronger symbolic proof first so that we can avoid clamping with a symbolic min().
if (!dom_analyzer_.CanProveLessEqualThanSymbolicShapeValue(extent, original_shape[i])) {
extent = tvm::min(original_shape[i], range->extent);
}
extent = dom_analyzer_.Simplify(extent);
} else {
min = dom_analyzer_.Simplify(range->min);
extent = dom_analyzer_.Simplify(range->extent);
}
// Check that the buffer extent is pure and not loop-dependent, since loop-dependent or
// data-dependent allocation is not supported yet. Otherwise fall back to the original
// buffer shape.
if (SideEffect(extent) > CallEffectKind::kPure) {
result_region.Set(i, original);
continue;
}
auto is_loop_var = [this](const VarNode* v) {
return std::any_of(ancestor_iters_.begin(), ancestor_iters_.end(),
[v](const IterVar& n) { return n->var.get() == v; });
};
if (UsesVar(extent, is_loop_var)) {
// Try to estimate a constant upper bound on the region's extent
int64_t upperbound = dom_analyzer_.const_int_bound(extent)->max_value;
if (upperbound != arith::ConstIntBound::kPosInf) {
extent = make_const(extent->dtype, upperbound);
} else {
result_region.Set(i, original);
continue;
}
}
result_region.Set(i, Range::FromMinExtent(min, extent));
}
}
/**************** Class members ****************/
/*! \brief Only collect accessed region within original buffer shape bound. */
bool collect_inbound_{true};
/*! \brief The iteration scopes from the current node up to the root. */
std::vector<IterVar> ancestor_iters_;
/*!
* \brief Map each buffer var to its scope depth, i.e. the number of ancestor iters at the
* definition point. ancestor_iters_[0 : depth] should not be relaxed when we evaluate
* this buffer's access regions.
*/
std::unordered_map<Var, size_t> buffer_scope_depth_;
/*! \brief Map the buffer var to all aliased buffers. */
std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>> var2buffer_;
/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra map from free vars to their iter range hints. */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief Unresolved conditions within current scope. */
std::vector<PrimExpr> pending_conditions_;
/*! \brief The analyzer aware of loop domains. */
arith::Analyzer dom_analyzer_;
/*! \brief The map from Buffer to its relaxed access set. */
std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_accesses_;
/*!
* \brief The map from Buffer to its entire access region; this is the result returned by
* Collect(). The region is finalized at the buffer's definition point, and we sanity-check
* that every buffer is defined only once.
*/
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
/*! \brief The map from Buffer to its access regions annotated by the current block. */
std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
access_annotations_;
/*! \brief The map from Buffer to its explicit access region annotated by the block. */
std::unordered_map<Buffer, BufferRegion, ObjectPtrHash, ObjectPtrEqual>
explicit_access_annotations_;
};
/*! \brief The storage alignment for a dimension */
struct DimAlignInfo {
/*! \brief The factor of the alignment */
int align_factor{0};
/*! \brief The offset of the alignment */
int align_offset{0};
};
struct BufferAllocInfo {
/*! \brief The buffer access region. */
Region region;
/*! \brief The storage alignment information. */
std::vector<DimAlignInfo> dim_aligns;
/*!
* \brief The reallocated buffer with minimal size.
* \note Undefined (null) if the buffer does not need to be reallocated (e.g. a parameter buffer).
*/
Buffer new_buffer;
};
/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
public:
explicit BufferCompactor(std::unordered_map<Var, BufferAllocInfo> buffer_info)
: buffer_info_(std::move(buffer_info)) {}
Stmt VisitStmt_(const BufferStoreNode* _op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
BufferStoreNode* op = store.CopyOnWrite();
RewriteBufferAccess(&op->buffer, &op->indices);
return store;
}
PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
BufferLoadNode* op = load.CopyOnWrite();
RewriteBufferAccess(&op->buffer, &op->indices);
return load;
}
Stmt VisitStmt_(const SBlockNode* op) final {
// Step 0. Check there is no Init part.
ICHECK(!op->init.defined());
// Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo.
ffi::Array<Buffer> alloc_buffers =
op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); });
// Step 2. Recursively rewrite BufferLoad/BufferStore.
SBlock block = Downcast<SBlock>(StmtExprMutator::VisitStmt_(op));
// Step 3. Update block signature.
SBlockNode* n = block.CopyOnWrite();
RewriteBufferRegions(&n->reads);
RewriteBufferRegions(&n->writes);
RewriteMatchBuffers(&n->match_buffers);
n->alloc_buffers = std::move(alloc_buffers);
return block;
}
Stmt VisitStmt_(const DeclBufferNode* op) final {
Buffer new_buffer = RewriteAllocBuffer(op->buffer);
auto n = CopyOnWrite(op);
n->buffer = std::move(new_buffer);
n->body = VisitStmt(op->body);
return DeclBuffer(n);
}
Stmt VisitStmt_(const AllocateNode* op) final {
Allocate allocate = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_info_.find(allocate->buffer_var);
if (it == buffer_info_.end()) {
return allocate;
}
// Rewrite the allocation shape if the corresponding buffer is in the buffer_info_ dict
// and the dtype is consistent, which means there is no buffer aliasing and the
// compaction is safe.
const Buffer& new_buffer = it->second.new_buffer;
if (op->dtype != new_buffer->dtype) {
return allocate;
}
ffi::Array<PrimExpr> new_shape = GetBufferAllocationShape(new_buffer);
auto n = allocate.CopyOnWrite();
ICHECK(n->buffer_var.same_as(new_buffer->data));
n->extents = new_shape;
return allocate;
}
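/*!
* \brief Return the compacted buffer for an allocation; return the original buffer if no
* compaction info was collected for it.
*/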
Buffer RewriteAllocBuffer(const Buffer& buffer) {
auto it = buffer_info_.find(buffer->data);
if (it != buffer_info_.end()) {
return it->second.new_buffer;
}
return buffer;
}
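/*!
* \brief Rewrite a buffer access: shift the indices so they are relative to the minimum of
* the compacted region, and redirect the access to the new buffer.
*/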
void RewriteBufferAccess(Buffer* buffer, ffi::Array<PrimExpr>* indices) const {
auto it = buffer_info_.find((*buffer)->data);
if (it == buffer_info_.end()) {
return;
}
const BufferAllocInfo& info = it->second;
ICHECK_EQ(indices->size(), info.region.size());
int ndim = info.region.size();
ffi::Array<PrimExpr> new_indices;
new_indices.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
new_indices.push_back((*indices)[i] - info.region[i]->min);
}
*buffer = info.new_buffer;
*indices = std::move(new_indices);
}
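/*!
* \brief Rewrite a buffer region: rebase each range onto the compacted region's minimum
* and redirect it to the new buffer.
*/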
void RewriteBufferRegion(Buffer* buffer, Region* region) const {
auto it = buffer_info_.find((*buffer)->data);
if (it == buffer_info_.end()) {
// Skip if the buffer is a parameter and therefore not compacted
return;
}
const BufferAllocInfo& info = it->second;
ICHECK_EQ(region->size(), info.region.size());
Region new_region;
new_region.reserve(info.region.size());
for (size_t i = 0; i < info.region.size(); ++i) {
const Range& range = (*region)[i];
new_region.push_back(Range::FromMinExtent(range->min - info.region[i]->min, range->extent));
}
*buffer = info.new_buffer;
*region = std::move(new_region);
}
void RewriteBufferRegions(ffi::Array<BufferRegion>* regions) const {
ffi::Array<BufferRegion> new_regions;
new_regions.reserve(regions->size());
for (const auto& region : *regions) {
BufferRegion buffer_region = region;
BufferRegionNode* p = buffer_region.CopyOnWrite();
RewriteBufferRegion(&p->buffer, &p->region);
new_regions.push_back(buffer_region);
}
*regions = std::move(new_regions);
}
void RewriteMatchBuffers(ffi::Array<MatchBufferRegion>* match_buffers) const {
ffi::Array<MatchBufferRegion> result;
result.reserve(match_buffers->size());
for (const auto& match_buffer : *match_buffers) {
const BufferRegion& buffer_region = match_buffer->source;
auto p = ffi::make_object<BufferRegionNode>(*buffer_region.get());
RewriteBufferRegion(&p->buffer, &p->region);
result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p)));
}
*match_buffers = std::move(result);
}
/*! \brief Map from buffer var to the allocation information of the corresponding buffer. */
std::unordered_map<Var, BufferAllocInfo> buffer_info_;
};
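/*!
* \brief Compute strides for the compacted buffer. When storage-align info is present, each
* dimension's stride is padded so that it satisfies the requested (factor, offset)
* alignment; otherwise an empty stride array is returned (compact layout).
*/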
ffi::Array<PrimExpr> CalcStrides(const BufferAllocInfo& alloc_info,
const ffi::Array<PrimExpr>& shape) {
std::vector<PrimExpr> strides;
if (alloc_info.dim_aligns.size()) {
ICHECK_EQ(alloc_info.dim_aligns.size(), shape.size());
strides.resize(shape.size());
PrimExpr stride = make_const(shape[0].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
DimAlignInfo info = alloc_info.dim_aligns[dim];
int align_factor = info.align_factor;
int align_offset = info.align_offset;
if (align_factor != 0) {
PrimExpr factor = make_const(stride.dtype(), align_factor);
PrimExpr offset = make_const(stride.dtype(), align_offset);
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
}
strides[dim] = stride;
stride = stride * shape[dim];
}
}
return strides;
}
Stmt BufferCompactorCompact(
const PrimFunc& f,
const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
const std::unordered_map<Var, StorageAlignAnnotation>& storage_align) {
// Collect buffer allocation info for non-aliased buffers
std::unordered_map<Var, BufferAllocInfo> buffer_info;
for (const auto& kv : regions) {
const Buffer& buffer = kv.first;
// set dim alignment info
Region region = kv.second;
BufferAllocInfo alloc_info;
auto it = storage_align.find(buffer->data);
if (it != storage_align.end()) {
std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
for (const StorageAlignTuple& dim_align : (*it).second) {
int dim = dim_align.get<1>();
int factor = dim_align.get<2>();
int offset = dim_align.get<3>();
dim_aligns.at(dim) = {factor, offset};
}
alloc_info.dim_aligns = std::move(dim_aligns);
}
// prepare new buffer
ffi::Array<PrimExpr> shape = region.Map([](const Range& range) { return range->extent; });
ffi::Array<PrimExpr> strides = CalcStrides(alloc_info, shape);
ObjectPtr<BufferNode> n = ffi::make_object<BufferNode>(*buffer.get());
n->shape = std::move(shape);
n->strides = std::move(strides);
alloc_info.new_buffer = Buffer(std::move(n));
alloc_info.region = region;
buffer_info.emplace(buffer->data, std::move(alloc_info));
}
BufferCompactor compactor(std::move(buffer_info));
Stmt stmt = compactor(f->body);
return stmt;
}
namespace transform {
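/*!
* \brief Compact buffer allocations: shrink each local allocation to the region it actually
* accesses and rebase all accesses onto the compacted buffer.
* \param is_strict If true, the collected regions are additionally clamped so they never
* exceed the original buffer shape.
*
* Illustrative sketch (pseudo-code, not tied to this codebase's script syntax): a buffer
* allocated as B[16, 16] but accessed only one row at a time inside a loop over i is
* reallocated as B[1, 16], with every access B[i, j] rewritten to B[0, j].
*/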
Pass CompactBufferAllocation(bool is_strict) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
PrimFuncNode* fptr = f.CopyOnWrite();
auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
auto storage_align = CollectStorageAlignAnnotation(f->body);
fptr->body = BufferCompactorCompact(f, region, storage_align);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "s_tir.CompactBufferAllocation", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("s_tir.transform.CompactBufferAllocation", CompactBufferAllocation);
}
} // namespace transform
} // namespace s_tir
} // namespace tvm