// blob: cdbe17508339675bc24badfc7e59d4d967e56748 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_permuted_layout.cc
* \brief The pass injects permuted layout for shared memory buffers to avoid bank conflicts.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"
#include "../../support/utils.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
using namespace arith;
using namespace runtime;
class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
public:
static PrimFunc Transform(PrimFunc func) {
Analyzer analyzer;
auto new_body = PermutedLayoutInjector(func, &analyzer)(func->body);
auto func_node = func.CopyOnWrite();
func_node->body = new_body;
return func;
}
private:
explicit PermutedLayoutInjector(PrimFunc func, Analyzer* analyzer)
: IRMutatorWithAnalyzer(analyzer) {
buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end());
}
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
ffi::Array<PrimExpr> PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) {
ICHECK(permute_);
// Index after vectorizing by 8
PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR),
col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR);
PrimExpr new_col_idx_outer;
if (row_size % 64 == 0) {
// Use 8 * 8 permuted layout
// Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
// Every row below corresponds to 32 banks
// 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
// 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
// 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
// 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
// 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
// 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
// 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
// 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
auto row_idx_sub = floormod(row_idx, 8);
new_col_idx_outer = col_idx_outer ^ row_idx_sub;
} else {
ICHECK(row_size % 32 == 0);
// Use 8 * 4 permuted layout
// Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
// Every row below corresponds to 16 banks
// 0 1 2 3 ==> 0 1 2 3
// 0 1 2 3 ==> 0 1 2 3
// 0 1 2 3 ==> 1 0 3 2
// 0 1 2 3 ==> 1 0 3 2
// 0 1 2 3 ==> 2 3 0 1
// 0 1 2 3 ==> 2 3 0 1
// 0 1 2 3 ==> 3 2 1 0
// 0 1 2 3 ==> 3 2 1 0
// View with 8 elements per row:
// 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
// 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
// 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
// 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
auto row_idx_sub = floormod(row_idx, 8);
new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, 2);
}
return {row_idx, analyzer_->Simplify(new_col_idx_outer * 8 + col_idx_inner)};
}
static bool CheckAnnotation(const Any& annotation) {
if (auto opt_str = annotation.as<ffi::String>()) {
// Support string annotation for backward compatibility
return *opt_str != "";
} else if (auto* node = annotation.as<IntImmNode>()) {
return node->value != 0;
} else if (auto opt_val = annotation.try_cast<int64_t>()) {
return *opt_val != 0;
} else {
LOG(FATAL) << "Invalid permuted layout annotation: " << annotation;
}
}
Stmt VisitStmt_(const BlockNode* op) final {
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : op->alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
}
for (auto match_buffer : op->match_buffers) {
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
}
if (op->annotations.count("permuted_layout") == 0 ||
!CheckAnnotation(op->annotations.at("permuted_layout"))) {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
auto prev_permute = permute_;
permute_ = true;
Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
permute_ = prev_permute;
// Erase the permuted_layout annotation after the pass
auto block_node = block.CopyOnWrite();
block_node->annotations.erase("permuted_layout");
return block;
}
int CheckAndGetBufferRowSize(Buffer buffer) {
CHECK(buffer->shape.size() >= 2)
<< "The dimension of Buffer \"" << buffer->name << "\" with shape " << buffer->shape
<< " should be at least 2";
auto dim = buffer->shape.size();
auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
auto buffer_col_size = buffer->shape[dim - 2].as<IntImmNode>()->value;
if (buffer_row_size % 64 != 0) {
CHECK(buffer_row_size % 32 == 0)
<< "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape
<< " is not supported since its second dimension is not divisible by 32";
CHECK(buffer_col_size % 2 == 0)
<< "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape
<< " is not supported since its first dimension is not divisible by 2 and second "
"dimension is not divisible by 64";
}
return buffer_row_size;
}
ffi::Array<PrimExpr> HandleBufferIndices(Buffer buffer, ffi::Array<PrimExpr> indices) {
auto buffer_row_size = CheckAndGetBufferRowSize(buffer);
// Mutate the last two indices
auto indices_size = indices.size();
PrimExpr row_idx = indices[indices_size - 2];
PrimExpr col_idx = indices[indices_size - 1];
auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
indices.Set(indices_size - 2, new_indices[0]);
indices.Set(indices_size - 1, new_indices[1]);
return indices;
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
// Rewrite write from global to shared.dyn or shared
// We assume the shape of the shared memory is [..., row_size, col_size],
// where row_size is divisible by 64, or divisible by 32 and col_size is divisible by 2.
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (!permute_ || store->buffer->shape.size() < 2) {
return store;
}
auto scope = StorageScope::Create(GetPtrStorageScope(store->buffer->data));
if (scope.rank != StorageRank::kShared) {
return store;
}
auto store_node = store.CopyOnWrite();
store_node->indices = HandleBufferIndices(store_node->buffer, store_node->indices);
return store;
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
// Rewrite load from shared or shared.dyn to global
auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (!permute_ || load->buffer->shape.size() < 2) {
return load;
}
auto scope = StorageScope::Create(GetPtrStorageScope(load->buffer->data));
if (scope.rank != StorageRank::kShared) {
return load;
}
auto load_node = load.CopyOnWrite();
load_node->indices = HandleBufferIndices(load_node->buffer, load_node->indices);
return load;
}
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
ffi::Optional<PrimExpr> offset = std::nullopt) {
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to
// smem_offset
CHECK(access_ptr->IsInstance<CallNode>())
<< "Invalid access ptr for permuted layout: " << access_ptr;
auto access_ptr_call = Downcast<Call>(access_ptr);
CHECK(access_ptr_call->op.same_as(builtin::tvm_access_ptr()))
<< "Invalid access ptr for permuted layout: " << access_ptr;
auto buffer_map_iter = buffer_map_.find(Downcast<Var>(access_ptr_call->args[1]));
CHECK(buffer_map_iter != buffer_map_.end())
<< "The buffer corresponding to data Var " << access_ptr_call->args[1] << " is not found";
int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
PrimExpr smem_offset = access_ptr_call->args[2] + (offset.defined() ? offset.value() : 0);
// Convert offset to 2-dimension, reindex it and convert it back
PrimExpr row_idx = floordiv(smem_offset, buffer_row_size);
PrimExpr col_idx = floormod(smem_offset, buffer_row_size);
auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
auto new_offset = analyzer_->Simplify(new_indices[0] * buffer_row_size + new_indices[1]);
auto new_access_ptr = access_ptr_call.CopyOnWrite();
new_access_ptr->args.Set(2, new_offset);
return access_ptr_call;
}
PrimExpr VisitExpr_(const CallNode* op) final {
// Rewrite from/to shared or shared.dyn to/from local
auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (!permute_) {
return call;
}
if (!call->op.same_as(builtin::ptx_ldmatrix()) && !call->op.same_as(builtin::mma_store())) {
return call;
}
if (call->op.same_as(builtin::ptx_ldmatrix())) {
// form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
// smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
auto access_ptr = call->args[5];
PrimExpr smem_offset = call->args[6];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset);
auto new_call = call.CopyOnWrite();
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
return call;
} else if (call->op.same_as(builtin::mma_store())) {
// TODO(yixin): mma_store is not fully tested yet
// because we will directly store result to Buffer instead of calling mma_store now
auto access_ptr = call->args[2];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
return call;
} else {
LOG(FATAL) << "Invalid call node: " << call;
}
}
static constexpr size_t VECTORIZE_FACTOR = 8;
static constexpr size_t BANK_SIZE_BYTES = 128;
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer> buffer_map_;
bool permute_ = false;
};
namespace transform {
/*!
 * \brief Create the PrimFunc-level pass that runs PermutedLayoutInjector::Transform.
 * \return The pass, registered under the name "tir.InjectPermutedLayout" at opt level 0.
 */
Pass InjectPermutedLayout() {
  // The module and pass context are not needed by the transformation.
  return CreatePrimFuncPass(
      [](PrimFunc f, IRModule, PassContext) {
        return PermutedLayoutInjector::Transform(std::move(f));
      },
      0, "tir.InjectPermutedLayout", {});
}
// Expose the pass constructor through the global function registry at static-init time.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.InjectPermutedLayout",
                                        InjectPermutedLayout);
}
} // namespace transform
} // namespace tir
} // namespace tvm