/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <array>
#include <stack>
#include "../../runtime/thread_storage_scope.h"
#include "../schedule/utils.h"
#include "./ir_utils.h"
#include "./memhammer_rewrite_rule.h"
#include "tvm/tir/stmt.h"
namespace tvm {
namespace tir {
using support::NDIntSet;
// rewrite rules
static InverseMapping inverse_mapping;
static CoalescedAccess coalesced_access;
static CreateLocalStage create_local_stage;
static SharedToWmma shared_to_wmma;
static WmmaToGlobal wmma_to_global;
static WmmaToShared wmma_to_shared;
static MmaToGlobal mma_to_global;
/*!
* \brief A class to perform auto padding.
*
 * A simple way to perform auto padding is to fix the padding size of every dimension at the
 * same time, calculate the precise access indices and the resulting bank conflicts,
 * and choose the combination with minimal conflict. However, this algorithm has exponential
 * complexity: with d dimensions and padding sizes in 0-31, we would need to evaluate the bank
 * conflict 32^{d-1} times.
 * We propose a fast incremental algorithm that works for affine inputs and evaluates the bank
 * conflict only 32*(d-1) times. Specifically, we first decide the optimal padding size for
 * dimension d-2, then for dimension d-3, ..., and finally for dimension 0. It involves 2 steps.
*
* First, we analyze how a typical warp accesses the shared memory banks.
 * A typical warp is obtained by setting all irrelevant loop vars to 0 and keeping only the
 * threads in a single warp.
* For each dimension, the access index is represented by
* x_1 * scale_1 + ... + x_n * scale_n (x_i is loop var)
 * Note: the affine property guarantees that the {x_i} are independent;
 * otherwise the algorithm would be incorrect.
 * We use this information to keep, for each dimension, a list called the "iteration space" that
 * records the resulting index as each x_i takes every possible value.
*
* For example, the index is [outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is threadIdx.y.
* tx is in [0, 16), and ty is in [0, 2).
* We will first get a warp access [ty, tx*4] because outer and vec are irrelevant loop vars.
* It's obvious that ty, tx*4 are both in the form of x_1 * scale_1 + ... + x_n * scale_n.
* In this case, we will keep lists {{0, 1}, {0, 4, ..., 60}}
*
 * Next, from the last dimension to the first, we choose the padding size with minimal conflict.
 * To calculate the conflict, we take the Cartesian product of the iteration spaces of this
 * dimension and all lower dimensions. Each point of the product space represents the access
 * index of a particular thread, from which we can compute the accessed memory bank. The
 * conflict is the highest access frequency among the banks.
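 *
 * A quick illustration (assuming the usual 32-bank, 32-bit-wide shared memory model): for a
 * float32 buffer A[32][32], the column access A[tx][0] sends every thread of a warp to bank
 * (tx * 32) % 32 == 0, a 32-way conflict. Padding the last dimension by one element gives a row
 * stride of 33 words, so the same access hits bank (tx * 33) % 32 == tx % 32 and becomes
 * conflict-free. The code below searches for such paddings automatically.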
*
*/
class AutoPadder {
public:
/**
   * \brief Pad the given buffers in shared memory
* \param buffers the given buffers
* \return the list of new padded buffers
*/
ffi::Array<Buffer> PadSharedMemory(const ffi::Array<Buffer>& buffers) {
ffi::Array<Buffer> result;
for (const Buffer& buffer : buffers) {
runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope());
if (scope.rank == runtime::StorageRank::kShared) {
auto iter_spaces = iter_spaces_[buffer.get()];
if (iter_spaces.empty()) {
result.push_back(buffer);
continue;
}
        // The access indices represented by points in the Cartesian product of the
        // lower-dimension iteration spaces
std::vector<std::vector<int>> low_dim_iter_space(iter_spaces.size(), std::vector<int>());
int n = buffer->shape.size();
int data_bits = buffer->dtype.bits();
// Step 1. initialize `low_dim_iter_space` with the iteration space of the last dim
for (int i = 0; i < static_cast<int>(iter_spaces.size()); i++) {
auto last_dim_iter_space = iter_spaces[i][n - 1];
low_dim_iter_space[i] = last_dim_iter_space;
}
PrimExpr stride = 1;
ffi::Array<PrimExpr> reverse_strides;
int pad_min = padding_min_.Get(buffer).value_or(Integer(1)).IntValue();
// Step 2. For each dimension, select a padding that has minimal bank conflict
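        // Note: the bank pattern repeats every 32 banks * 32 bits, i.e. every
        // (32 * 32 / data_bits) elements, so paddings beyond that (or beyond max_pad_factor_
        // of the row size) bring no benefit.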
for (int k = n - 2; k >= 0; k--) { // dims
int max_pad_size =
std::min(static_cast<int>(max_pad_factor_ *
(stride * buffer->shape[k + 1]).as<IntImmNode>()->value),
32 * 32 / data_bits);
int min_conflict = INT32_MAX;
int min_conflict_pad = -1;
for (int pad = 0; pad <= max_pad_size; pad += pad_min) { // select padding
int padded_stride = ((stride * buffer->shape[k + 1]).as<IntImmNode>()->value + pad) %
(32 * 32 / data_bits);
int conflict = 0;
for (int i = 0; i < static_cast<int>(iter_spaces.size()); i++) { // accesses
auto iter_space = iter_spaces[i][k];
int bank[32]{0};
for (int v1 : iter_space) {
for (int v2 : low_dim_iter_space[i]) {
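                  // v1 * padded_stride + v2 is the element offset; * data_bits gives the bit
                  // offset, / 32 the 32-bit word index, and % 32 the bank index.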
int comb = (v1 * padded_stride + v2) * data_bits / 32 % 32;
bank[comb]++;
}
}
for (int j = 0; j < 32; j++) {
conflict = std::max(conflict, bank[j]);
}
}
if (conflict < min_conflict) {
min_conflict = conflict;
min_conflict_pad = pad;
}
}
          // Update `low_dim_iter_space` with the indices induced by the chosen padding
for (int i = 0; i < static_cast<int>(iter_spaces.size()); i++) { // accesses
auto iter_space = iter_spaces[i][k];
if (!iter_space.empty()) {
int padded_stride =
((stride * buffer->shape[k + 1]).as<IntImmNode>()->value + min_conflict_pad) %
(32 * 32 / data_bits);
std::vector<int> span;
for (int v1 : iter_space) {
for (int v2 : low_dim_iter_space[i]) {
                  // Keep element offsets (the bank period is 32 * 32 / data_bits elements);
                  // the conflict computation above converts these offsets to banks.
                  span.push_back((v1 * padded_stride + v2) % (32 * 32 / data_bits));
}
}
low_dim_iter_space[i] = span;
}
}
stride = stride * buffer->shape[k + 1] + min_conflict_pad;
reverse_strides.push_back(stride);
}
// Step 3. create the new padded buffer
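        // The padding is materialized as explicit strides (innermost stride 1) on a copy of the
        // buffer, so the logical shape stays the same and only the physical layout is widened.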
ObjectPtr<BufferNode> b = ffi::make_object<BufferNode>(*buffer.get());
ffi::Array<PrimExpr> strides;
for (int i = static_cast<int>(reverse_strides.size()) - 1; i >= 0; i--) {
strides.push_back(reverse_strides[i]);
}
strides.push_back(1);
b->strides = strides;
Buffer new_buffer(b);
result.push_back(new_buffer);
padded_buffer_map_.Set(buffer, new_buffer);
} else {
result.push_back(buffer);
}
}
return result;
}
/**
   * \brief Replace all occurrences of the old buffers with the new padded buffers in the stmt
* \param stmt the stmt to do replacement
* \return the stmt after replacement
*/
Stmt RewriteBufferAccess(const Stmt& stmt) {
class Rewriter : public StmtExprMutator {
public:
explicit Rewriter(const ffi::Map<Buffer, Buffer>& buffer_map) : buffer_map_(buffer_map) {}
private:
PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
BufferLoadNode* op = load.CopyOnWrite();
if (buffer_map_.count(op->buffer)) {
op->buffer = buffer_map_[op->buffer];
}
return load;
}
Stmt VisitStmt_(const BufferStoreNode* _op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
BufferStoreNode* op = store.CopyOnWrite();
if (buffer_map_.count(op->buffer)) {
op->buffer = buffer_map_[op->buffer];
}
return store;
}
Stmt VisitStmt_(const BlockNode* op) final {
// To reduce the number of blocks in block sref reuse map, we check whether the block is
// really mutated (i.e., the old buffer appears in the block). If so, we return the block
// after mutation. Otherwise we just return the original block.
bool changed = false;
// Step 1. Mutate the read region.
ffi::Array<BufferRegion> reads;
for (const BufferRegion& read : op->reads) {
if (buffer_map_.count(read->buffer)) {
changed = true;
reads.push_back(BufferRegion(buffer_map_[read->buffer], read->region));
} else {
reads.push_back(read);
}
}
// Step 2. Mutate the write region.
ffi::Array<BufferRegion> writes;
for (const BufferRegion& write : op->writes) {
if (buffer_map_.count(write->buffer)) {
changed = true;
writes.push_back(BufferRegion(buffer_map_[write->buffer], write->region));
} else {
writes.push_back(write);
}
}
        // Step 3. Mutate `match_buffers`. If an old buffer appears as the source of a
        // MatchBufferRegion, replace it with the corresponding padded buffer.
ffi::Array<MatchBufferRegion> match_buffers;
for (const MatchBufferRegion& match_buffer : op->match_buffers) {
if (buffer_map_.count(match_buffer->source->buffer)) {
changed = true;
Buffer new_buffer = buffer_map_[match_buffer->source->buffer];
match_buffers.push_back(MatchBufferRegion(
match_buffer->buffer, BufferRegion(new_buffer, match_buffer->source->region)));
} else {
match_buffers.push_back(match_buffer);
}
}
        // Step 4. Recursively mutate the block.
Stmt res = StmtMutator::VisitStmt_(op);
if (res.get() != op) {
changed = true;
}
if (changed) {
ObjectPtr<BlockNode> block = CopyOnWrite(res.as<BlockNode>());
block->reads = std::move(reads);
block->writes = std::move(writes);
block->match_buffers = std::move(match_buffers);
return Stmt(block);
} else {
return ffi::GetRef<Block>(op);
}
}
const ffi::Map<Buffer, Buffer>& buffer_map_;
};
Rewriter rewriter(padded_buffer_map_);
return rewriter(stmt);
}
/**
   * \brief A pattern equivalent to scale * loop_var, where loop_var ranges over [0, extent)
*/
struct Pattern {
int extent;
int scale;
};
/**
   * \brief Collect access patterns from indices
*/
class PatternCollector : public StmtExprVisitor {
void VisitExpr_(const VarNode* op) final {
if (!success_) {
return;
}
int extent = var_range_[ffi::GetRef<Var>(op)]->extent.as<IntImmNode>()->value;
if (extent > 1) {
stack_.push({{extent, 1}});
} else {
stack_.push({});
}
}
void VisitExpr_(const AddNode* op) final {
ExprVisitor::VisitExpr_(op);
if (!success_) {
return;
}
std::vector<Pattern> merged_patterns;
std::vector<Pattern> r = stack_.top();
stack_.pop();
std::vector<Pattern> l = stack_.top();
stack_.pop();
for (const Pattern& pattern : l) {
merged_patterns.push_back(pattern);
}
for (const Pattern& pattern : r) {
merged_patterns.push_back(pattern);
}
if (merged_patterns.empty()) {
stack_.push({});
return;
}
std::vector<Pattern> ret;
ret.push_back(merged_patterns[0]);
      // Fuse a pattern into the previous one when it exactly fills the previous pattern's stride
      // (e.g. {extent=4, scale=1} below {extent=16, scale=4} fuses into {extent=64, scale=1});
      // patterns that do not fuse are kept as independent terms.
      for (int i = 1; i < static_cast<int>(merged_patterns.size()); i++) {
        Pattern prev_pattern = ret.back();
        if (merged_patterns[i].extent * merged_patterns[i].scale == prev_pattern.scale) {
          ret.pop_back();
          ret.push_back(
              {prev_pattern.extent * merged_patterns[i].extent, merged_patterns[i].scale});
        } else {
          ret.push_back(merged_patterns[i]);
        }
      }
stack_.push(ret);
}
void VisitExpr_(const FloorDivNode* op) final {
ExprVisitor::VisitExpr_(op);
if (!success_) {
return;
}
std::vector<Pattern> inner = stack_.top();
stack_.pop();
int lower_factor = op->b.as<IntImmNode>()->value;
std::vector<Pattern> ret;
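      // A Pattern {extent, scale} stands for the value set {0, scale, ..., (extent - 1) * scale}.
      // Dividing by `lower_factor` either rescales the pattern, shrinks it, or (when the whole
      // pattern is smaller than `lower_factor`) collapses it to 0, in which case it is dropped.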
for (const Pattern& pattern : inner) {
if (pattern.scale >= lower_factor) {
if (pattern.scale % lower_factor == 0) {
ret.push_back({pattern.extent, pattern.scale / lower_factor});
} else {
success_ = false;
}
} else if (pattern.scale * pattern.extent > lower_factor) {
if ((pattern.scale * pattern.extent) % lower_factor == 0) {
ret.push_back({pattern.extent * pattern.scale / lower_factor, 1});
} else {
success_ = false;
}
}
}
stack_.push(ret);
}
void VisitExpr_(const FloorModNode* op) final {
ExprVisitor::VisitExpr_(op);
if (!success_) {
return;
}
std::vector<Pattern> inner = stack_.top();
stack_.pop();
int extent = op->b.as<IntImmNode>()->value;
std::vector<Pattern> ret;
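      // Keep only the part of each pattern that survives `% extent`: patterns whose scale
      // reaches `extent` are dropped, and smaller-scale patterns are clipped so that their
      // values stay below `extent`.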
for (const Pattern& pattern : inner) {
if (pattern.scale < extent) {
if (extent % pattern.scale == 0) {
if (extent / pattern.scale < pattern.extent) {
ret.push_back({extent / pattern.scale, pattern.scale});
} else {
ret.push_back({pattern.extent, pattern.scale});
}
} else {
success_ = false;
}
}
}
stack_.push(ret);
}
void VisitExpr_(const MulNode* op) final {
ExprVisitor::VisitExpr_(op);
if (!success_) {
return;
}
std::vector<Pattern> inner = stack_.top();
stack_.pop();
int scale = op->b.as<IntImmNode>()->value;
std::vector<Pattern> ret;
for (const Pattern& pattern : inner) {
ret.push_back({pattern.extent, pattern.scale * scale});
}
stack_.push(ret);
}
public:
explicit PatternCollector(const ffi::Map<Var, Range>& var_range) : var_range_(var_range) {}
/*!
     * \brief Collect the iteration space for the given indices. The iteration space is the list
     * of possible values that an index can take (duplicates are not removed).
* For example, the input is [ty, tx*4], where tx is in [0, 16), and ty is in [0, 2).
* The output would be {{0, 1}, {0, 4, ..., 60}}
* \param indices The indices to analyze
* \param var_range The range of loop variables
* \param data_bits The size of dtype in bits
* \return The iteration space. The first array represents dimensions, and the second array
* represents the iteration space of one dimension
*/
static std::vector<std::vector<int>> CollectIterationSpace(
const ffi::Array<PrimExpr>& indices, const ffi::Map<Var, Range>& var_range, int data_bits) {
PatternCollector collector(var_range);
std::vector<std::vector<int>> ret;
for (int i = 0; i < static_cast<int>(indices.size()); i++) {
collector(indices[i]);
if (collector.success_ && collector.stack_.size() == 1) {
auto patterns = collector.stack_.top();
int extent_prod = 1;
for (const Pattern& p : patterns) {
extent_prod *= p.extent;
}
std::vector<int> iter_space;
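          // Enumerate every combination of the collected patterns: decode `thread_id` as a
          // mixed-radix number whose digits are the pattern extents, and accumulate the index.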
for (int thread_id = 0; thread_id < extent_prod; thread_id++) {
int index = 0;
int n = thread_id;
for (int j = static_cast<int>(patterns.size()) - 1; j >= 0; j--) {
int val = n % patterns[j].extent;
index += val * patterns[j].scale;
n /= patterns[j].extent;
}
iter_space.push_back(index);
}
ret.push_back(iter_space);
collector.stack_.pop();
} else {
ret.push_back({});
}
}
return ret;
}
std::stack<std::vector<Pattern>> stack_;
const ffi::Map<Var, Range>& var_range_;
bool success_ = true;
};
  /*! \brief A utility class that applies CollectIterationSpace to each buffer access */
class IterSpaceAnalyzer : public StmtExprVisitor {
public:
IterSpaceAnalyzer(const ffi::Map<Var, PrimExpr>& substitute_map, AutoPadder* self,
int data_bits, const ffi::Map<ffi::String, Integer> warp_thread_extent)
: substitute_map_(substitute_map),
self(self),
data_bits_(data_bits),
warp_thread_extent_(warp_thread_extent) {}
private:
bool CheckVarContiguous(PrimExpr e, Var var, const ffi::Map<Var, PrimExpr>& subst_map) {
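      // Treat the index `e` as contiguous in `var` unless the analyzer can prove that
      // substituting var = 1 versus var = 0 changes the index by something other than 1.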
PrimExpr e1 = Substitute(e, [var](const Var& v) -> ffi::Optional<PrimExpr> {
if (v.same_as(var)) {
return Integer(0);
} else {
return v;
}
});
PrimExpr e2 = Substitute(e, [var](const Var& v) -> ffi::Optional<PrimExpr> {
if (v.same_as(var)) {
return Integer(1);
} else {
return v;
}
});
arith::Analyzer analyzer;
return !analyzer.CanProve(Substitute(e2 - e1, subst_map) != 1);
}
void VisitStmt_(const ForNode* op) final {
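      // Build the "typical warp": non-thread loops are pinned to their minimum, while
      // thread-bound loops are restricted to the extent they occupy within a single warp.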
if (op->kind != ForKind::kThreadBinding) {
substitute_map_.Set(op->loop_var, op->min);
} else {
Integer extent =
warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1);
var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, extent));
}
if (op->kind == ForKind::kVectorized) {
vector_var = op->loop_var;
vector_length_ = op->extent.as<IntImmNode>()->value;
}
StmtExprVisitor::VisitStmt_(op);
if (op->kind == ForKind::kVectorized) {
vector_length_ = -1;
}
if (op->kind != ForKind::kThreadBinding) {
substitute_map_.erase(op->loop_var);
}
}
/*!
* \brief Take a typical warp and collect the iteration space for buffer store
* For example, the access is A[outer*2+ty, tx*4+vec] = xxx, where tx is threadIdx.x, and ty is
* threadIdx.y. tx is in [0, 16), and ty is in [0, 2).
* The iteration space would be {{0, 1}, {0, 4, ..., 60}}.
* \param op the buffer store
*/
void VisitStmt_(const BufferStoreNode* op) final {
runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope());
if (scope.rank == runtime::StorageRank::kShared) {
ffi::Array<PrimExpr> substitued_indices;
arith::Analyzer analyzer;
for (const PrimExpr& e : op->indices) {
substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_)));
}
std::vector<std::vector<int>> iter_space =
PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_);
if (!iter_space.empty()) {
self->iter_spaces_[op->buffer.get()].push_back(iter_space);
}
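        // If the innermost index is contiguous in the vectorized loop var, only consider
        // paddings that are multiples of the vector length so vectorized accesses stay
        // contiguous.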
if (vector_length_ != -1 &&
CheckVarContiguous(op->indices.back(), vector_var, substitute_map_)) {
Integer m = self->padding_min_.Get(op->buffer).value_or(1);
self->padding_min_.Set(op->buffer, Downcast<Integer>(max(vector_length_, m)));
}
}
StmtExprVisitor::VisitStmt_(op);
}
/*!
* \brief Take a typical warp and collect the iteration space for buffer load
* For example, the access is xxx = A[outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is
* threadIdx.y. tx is in [0, 16), and ty is in [0, 2).
* The iteration space would be {{0, 1}, {0, 4, ..., 60}}.
* \param op the buffer load
*/
void VisitExpr_(const BufferLoadNode* op) final {
runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope());
if (scope.rank == runtime::StorageRank::kShared) {
ffi::Array<PrimExpr> substitued_indices;
arith::Analyzer analyzer;
for (const PrimExpr& e : op->indices) {
substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_)));
}
std::vector<std::vector<int>> iter_space =
PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_);
if (!iter_space.empty()) {
self->iter_spaces_[op->buffer.get()].push_back(iter_space);
}
if (vector_length_ != -1 &&
CheckVarContiguous(substitued_indices.back(), vector_var, substitute_map_)) {
Integer m = self->padding_min_.Get(op->buffer).value_or(1);
self->padding_min_.Set(op->buffer, Downcast<Integer>(max(vector_length_, m)));
}
}
StmtExprVisitor::VisitExpr_(op);
}
/*!
* \brief Take a typical warp and collect the iteration space for load_matrix_sync and
* store_matrix_sync
     * For example, the access region is A[y*16 : y*16+16, x*16 : x*16+16], where y and x are not
     * bound to threadIdx. The iteration space would be {{0, 1, ..., 15}, {0, 1, ..., 15}}.
* \param op the call node
*/
void VisitStmt_(const BlockNode* op) final {
if (const auto* eval = op->body.as<EvaluateNode>()) {
if (const auto* call = eval->value.as<CallNode>()) {
if (call->op == builtin::tvm_load_matrix_sync() ||
call->op == builtin::tvm_store_matrix_sync()) {
for (const MatchBufferRegion& r : op->match_buffers) {
Buffer src_buffer = r->source->buffer;
runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope());
if (scope.rank == runtime::StorageRank::kShared) {
Region region = r->source->region;
ffi::Array<PrimExpr> indices;
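                // Model each dimension of the matched region with a fresh var ranging over its
                // extent, so the wmma tile access can be analyzed like an ordinary indexed
                // access.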
for (int i = 0; i < static_cast<int>(region.size()); i++) {
Var var("region" + std::to_string(i));
indices.push_back(region[i]->min + var);
var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent));
}
ffi::Array<PrimExpr> substitued_indices;
arith::Analyzer analyzer;
for (const PrimExpr& e : indices) {
substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_)));
}
std::vector<std::vector<int>> iter_space = PatternCollector::CollectIterationSpace(
substitued_indices, var_range_, data_bits_);
if (!iter_space.empty()) {
self->iter_spaces_[src_buffer.get()].push_back(iter_space);
}
}
}
}
}
}
}
ffi::Map<Var, PrimExpr> substitute_map_;
AutoPadder* self;
int data_bits_;
ffi::Map<ffi::String, Integer> warp_thread_extent_;
ffi::Map<Var, Range> var_range_;
int vector_length_ = -1;
Var vector_var;
};
/*!
* \brief Analyze the shared memory access
* \param stmt The data copy
* \param outer_loops The outer loops of the stmt
* \param data_bits The length of dtype in bits
* \param thread_extent The extents of all thread binding loops
*/
void AnalyzeSharedMemoryAccess(const Stmt& stmt, const ffi::Array<For>& outer_loops,
int data_bits,
const ffi::Map<ffi::String, Integer>& thread_extent) {
ffi::Map<ffi::String, Integer> warp_thread_extent;
Integer prod = 1;
ffi::Array<ffi::String> thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"};
arith::Analyzer analyzer;
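    // Distribute the 32 lanes of a warp over threadIdx.x/y/z: take each thread extent in full
    // until the product reaches 32, then clip the last one so the product is exactly 32.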
for (int i = 0; i < 3; i++) {
Integer extent = thread_extent.Get(thread_tags[i]).value_or(1);
if (analyzer.CanProve(prod * extent >= 32)) {
warp_thread_extent.Set(thread_tags[i], Downcast<Integer>(floordiv(32, prod)));
prod *= floordiv(32, prod);
break;
} else {
warp_thread_extent.Set(thread_tags[i], Downcast<Integer>(extent));
prod *= extent;
}
}
ffi::Map<Var, PrimExpr> substitute_map;
for (const For& loop : outer_loops) {
substitute_map.Set(loop->loop_var, loop->min);
}
IterSpaceAnalyzer iter_space_analyzer(substitute_map, this, data_bits, warp_thread_extent);
iter_space_analyzer(stmt);
}
private:
/*! \brief A map from the old buffers to the new padded buffers */
ffi::Map<Buffer, Buffer> padded_buffer_map_;
  /*! \brief A map from each buffer to the iteration spaces of its accesses */
std::unordered_map<const BufferNode*, std::vector<std::vector<std::vector<int>>>> iter_spaces_;
/*! \brief A map from each buffer to their minimal padding size */
ffi::Map<Buffer, Integer> padding_min_;
  /*! \brief The maximum padding size relative to the original shape */
const double max_pad_factor_ = 0.25;
friend class AutoCopyMutator;
};
class AutoCopyMutator : public StmtExprMutator {
public:
explicit AutoCopyMutator(ffi::Map<ffi::String, Integer> thread_extent)
: thread_extent_(thread_extent) {}
/**
* \brief Replace old buffers with padded buffers in the stmt
* \param stmt The stmt to rewrite
* \return The stmt after rewrite
*/
Stmt RewritePaddingBody(const Stmt& stmt) { return padder.RewriteBufferAccess(stmt); }
private:
Stmt VisitStmt_(const BlockNode* op) final {
Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
// only rewrite the block annotated with "auto_copy"
if (!GetAnn<bool>(op, tir::attr::auto_copy).value_or(false)) {
BlockNode* n = block.CopyOnWrite();
n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers));
return block;
}
ICHECK_EQ(block->writes.size(), 1);
ICHECK_GE(block->reads.size(), 1);
BufferRegion target_read = block->reads[0];
if (block->reads.size() > 1) {
bool found = false;
for (size_t i = 0; i < block->reads.size(); i++) {
if (block->reads[i]->buffer.scope() == "wmma.accumulator") {
found = true;
target_read = block->reads[i];
}
}
      ICHECK(found) << "Expected one of the multiple buffer reads to be from a wmma.accumulator "
                       "buffer";
}
int data_bits = target_read->buffer->dtype.bits();
ConstraintSet constraints(this->thread_extent_, //
this->outer_loops_, //
target_read, //
block->writes[0], //
data_bits, //
block->annotations);
BlockNode* n = block.CopyOnWrite();
OutputSet outputs;
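    // Run the data-copy body through the rewrite rules in order; a rule may rewrite the body and
    // request extra allocations or minimum padding sizes through `outputs`.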
for (RewriteRule* rule : rules) {
n->body = rule->Apply(std::move(n->body), constraints, &outputs);
}
for (const Buffer& buffer : outputs.alloc_buffer) {
n->alloc_buffers.push_back(buffer);
}
for (const auto& p : outputs.padding_min) {
Integer m = padder.padding_min_.Get(p.first).value_or(1);
padder.padding_min_.Set(p.first, Downcast<Integer>(max(p.second, m)));
}
padder.AnalyzeSharedMemoryAccess(block->body, outer_loops_, data_bits, thread_extent_);
n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers));
return block;
}
Stmt VisitStmt_(const ForNode* op) final {
outer_loops_.push_back(ffi::GetRef<For>(op));
Stmt stmt = StmtMutator::VisitStmt_(op);
outer_loops_.pop_back();
return stmt;
}
/*! \brief Thread extents collected. */
ffi::Map<ffi::String, Integer> thread_extent_;
/*! \brief The outer loops during recursive visit */
ffi::Array<For> outer_loops_;
/*! \brief Calculating optimal padding size */
AutoPadder padder;
/*! \brief All rewrite rules. */
const std::array<RewriteRule*, 7> rules = {&inverse_mapping, //
&coalesced_access, //
&create_local_stage, //
&shared_to_wmma, //
&wmma_to_global, //
&wmma_to_shared, //
&mma_to_global};
};
/*!
* \brief Collect the extent for all thread binding loops.
*/
class ThreadExtentCollector : public StmtVisitor {
public:
static ffi::Map<ffi::String, Integer> CollectThreadExtent(const Stmt& stmt) {
ThreadExtentCollector collector;
collector(stmt);
return collector.thread_extent_;
}
private:
void VisitStmt_(const BlockNode* op) final {
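    // A block annotated with "warp_execution" is executed cooperatively by a full warp, so
    // assume an implicit threadIdx.x extent of 32 even without an explicit thread-binding loop.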
if (ffi::Optional<Integer> warp_execution = GetAnn<Integer>(op, "warp_execution")) {
if (warp_execution.value()->value != 0) {
thread_extent_.Set("threadIdx.x", Integer(32));
}
}
StmtVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode* op) final {
if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) {
if (const auto* extent = op->extent.as<IntImmNode>()) {
thread_extent_.Set(op->thread_binding.value()->thread_tag, ffi::GetRef<Integer>(extent));
}
}
StmtVisitor::VisitStmt_(op);
}
/*! \brief the map from thread tag to its extent */
ffi::Map<ffi::String, Integer> thread_extent_;
};
namespace transform {
Pass LowerAutoCopy() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
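    // First rewrite the annotated copy blocks (padding shared-memory buffers along the way),
    // then redirect every remaining access from the original buffers to their padded versions.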
AutoCopyMutator mutator(ThreadExtentCollector::CollectThreadExtent(n->body));
n->body = mutator(std::move(n->body));
n->body = mutator.RewritePaddingBody(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tir.transform.LowerAutoCopy", LowerAutoCopy);
}
} // namespace transform
} // namespace tir
} // namespace tvm