/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_thread_allreduce.cc
* \brief Lower allreduce to device-implementable IR.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "update_pointer_storage_scope.h"
namespace tvm {
namespace tir {
class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
public:
explicit UpdatePointerStorageScopeAllReduce(
const std::unordered_map<const VarNode*, String>& new_storage_scopes)
: UpdatePointerStorageScope(new_storage_scopes) {}
Stmt VisitStmt_(const AllocateNode* op) final {
auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
auto new_scope = GetPtrStorageScope(remapped);
if (new_scope != GetPtrStorageScope(op->buffer_var)) {
Stmt body = StmtExprMutator::VisitStmt(op->body);
if (new_scope == "shared") {
// use volatile access to shared buffer.
body = AttrStmt(remapped, attr::volatile_scope, 1, body);
}
return Allocate(remapped, op->dtype, op->extents, op->condition, body);
}
return StmtExprMutator::VisitStmt_(op);
}
};
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(const TargetNode* target)
: target_(target),
warp_size_(target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
} else if (op->attr_key == attr::reduce_scope) {
const CommReducerNode* combiner = op->node.as<CommReducerNode>();
ICHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = StmtExprMutator::VisitStmt_(op);
reduce_combiner_.pop_back();
return ret;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const EvaluateNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<EvaluateNode>();
const CallNode* call = op->value.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_thread_allreduce())) {
return MakeAllreduce(call);
} else {
return stmt;
}
}
Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
new_storage_scopes_[repl->buffer_var.get()] = "local";
} else {
new_storage_scopes_[repl->buffer_var.get()] = "shared";
}
return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
} else {
return stmt;
}
}
PrimExpr VisitExpr_(const LoadNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
return PrimExpr();
}
Stmt VisitStmt_(const StoreNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
return Stmt();
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
{
auto it = load_remap_.find(op->buffer->data.get());
if (it != load_remap_.end()) {
for (const auto& index : op->indices) {
ICHECK(is_zero(index));
}
return it->second;
}
}
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
op = load.get();
{
auto it = buf_remap_.find(op->buffer.get());
if (it != buf_remap_.end()) {
return BufferLoad(it->second, op->indices, op->span);
}
}
{
auto it = var_remap_.find(op->buffer->data.get());
if (it != var_remap_.end()) {
Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
op->buffer->data_alignment, op->buffer->offset_factor,
op->buffer->buffer_type, op->buffer->axis_separators,
op->buffer->span);
buf_remap_[op->buffer.get()] = remapped_buffer;
return BufferLoad(remapped_buffer, op->indices, op->span);
}
}
return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = store_remap_.find(store->buffer.get());
if (it != store_remap_.end()) {
for (const auto& index : op->indices) {
ICHECK(is_zero(index));
}
auto writer = store.CopyOnWrite();
writer->buffer = it->second;
return std::move(store);
}
{
auto it = buf_remap_.find(store->buffer.get());
if (it != buf_remap_.end()) {
return BufferStore(it->second, store->value, store->indices, store->span);
}
}
{
auto it = var_remap_.find(store->buffer->data.get());
if (it != var_remap_.end()) {
Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
store->buffer->strides, store->buffer->elem_offset,
store->buffer->name, store->buffer->data_alignment,
store->buffer->offset_factor, store->buffer->buffer_type,
store->buffer->axis_separators, store->buffer->span);
buf_remap_[store->buffer.get()] = remapped_buffer;
return BufferStore(remapped_buffer, store->value, store->indices, store->span);
}
}
return std::move(store);
}
std::unordered_map<const VarNode*, String> new_storage_scopes_;
private:
// Thread entry
struct ThreadEntry {
runtime::ThreadScope scope;
IterVar iv;
int extent;
// comparator
bool operator<(const ThreadEntry& other) const {
return scope.dim_index < other.scope.dim_index;
}
};
// Make the lowering of a tvm_thread_allreduce call.
Stmt MakeAllreduce(const CallNode* call) {
ICHECK(!reduce_combiner_.empty());
const CommReducerNode* combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
const IntImmNode* size_of_args = call->args[0].as<IntImmNode>();
ICHECK(size_of_args) << call->args[0]->GetTypeKey();
ICHECK_EQ(size, size_of_args->value);
Array<PrimExpr> inits = combiner->identity_element;
std::vector<PrimExpr> values(size);
std::vector<DataType> types(size);
PrimExpr cond = call->args[size + 1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1 + idx];
if (!is_one(cond)) {
values[idx] = Select(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].dtype();
}
std::vector<Buffer> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
PrimExpr arg = call->args[2 + size + idx];
// Loads from boolean buffers may have cast nodes inserted by
// earlier passes.
if (auto cast = arg.as<CastNode>()) {
arg = cast->value;
}
buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
}
std::unordered_set<const VarNode*> reduce_set;
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const VarNode* v = call->args[i].as<VarNode>();
// A simple optimization replaces an iteration variable with a constant
// when the extent of the iteration is 1. As threaded IterVars always start from 0,
// we can just ignore this variable in that case.
if (v) {
reduce_set.insert(v);
} else {
ICHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
<< "arg " << i << " should be a VarNode or IntImmNode";
}
}
size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
for (const AttrStmtNode* attr : thread_extents_) {
ThreadEntry e;
IterVar iv = Downcast<IterVar>(attr->node);
e.scope = runtime::ThreadScope::Create(iv->thread_tag);
e.iv = iv;
ICHECK_LE(e.scope.rank, 1);
ICHECK_GE(e.scope.dim_index, 0) << "vthread does not work with cross-thread reduction";
if (e.scope.rank == 1) {
const auto* ptr = attr->value.as<IntImmNode>();
ICHECK(ptr) << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
// ignore threads whose extent is 1; the iteration variable is always 0
if (e.extent == 1) {
continue;
}
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
} else {
vpar.push_back(e);
}
}
}
ICHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce indices are present in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
// The extents of the flattened reduce and group indices.
int reduce_extent, group_extent;
PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
PrimExpr group_index = FlattenThread(vpar, &group_extent);
// the longest contiguous reduce extent after flattening
int contiguous_reduce_extent = 1;
std::vector<std::tuple<int, int, bool>> block_threads; // tuple(dim_index, extent, is_reduce)
for (const ThreadEntry& thr : vred) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
}
}
for (const ThreadEntry& thr : vpar) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
}
}
// sort according to dim_index
std::sort(block_threads.begin(), block_threads.end());
for (auto&& thr_attr : block_threads) {
int dim_index, extent;
bool is_reduce;
std::tie(dim_index, extent, is_reduce) = thr_attr;
if (is_reduce) {
contiguous_reduce_extent *= extent;
} else {
break;
}
}
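// For example, if threadIdx.x (extent 32) is a reduce axis and threadIdx.y
// (extent 4) is not, contiguous_reduce_extent becomes 32; if the innermost
// threadIdx dimension is not a reduce axis, the loop breaks immediately and
// contiguous_reduce_extent stays 1.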
std::vector<Stmt> seq;
std::vector<Var> shared_buffer_vars(size);
std::vector<Buffer> shared_bufs(size);
std::vector<Buffer> local_bufs;
//
// This is an optimization. For small reduction sizes, it may be beneficial
// for a single warp to perform the entire reduction. No trips to shared
// memory and no cross-warp synchronizations are required.
// The following code emits the reduction as follows:
//
// Allocate reduction vars v[i], i = 0..size-1
//
// for offset from the largest power of two below reduce_extent down to 1, halving each step
//
// a <- load(v[i])
// b <- shuffle_down(load(v[i]), offset)
// v[i] <- reduction(a, b)
//
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
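// As a rough illustration only (the pass emits TIR intrinsics, not this exact
// CUDA), a float32 sum over a full 32-lane warp lowers to a pattern along the
// lines of:
//
//   unsigned mask = __activemask();
//   float red_buf0 = v;                              // local reduction var
//   for (int offset = 16; offset > 0; offset /= 2) {
//     float t0 = __shfl_down_sync(mask, red_buf0, offset, 32);
//     red_buf0 = red_buf0 + t0;
//   }
//   red_buf0 = __shfl_sync(mask, red_buf0, 0, 32);   // broadcast lane 0
//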
if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
//
// This is the index into the reduction variable, one reduction
// variable per warp. Local scope seems easier to reason about without
// relying on a pattern-match pass to fix it later.
Array<PrimExpr> zero_indices = {0};
for (size_t idx = 0; idx < size; ++idx) {
Array<PrimExpr> shape = {1};
Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
Var buffer_var = buffer->data;
shared_buffer_vars[idx] = buffer_var;
shared_bufs[idx] = buffer;
PrimExpr pred = const_true(types[idx].lanes());
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
// Uses a local variable to store the shuffled data. Later
// on, an allocation will be built for this local variable.
local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
}
// Compute the mask for this reducer, as the reducer may sit inside
// divergent control flow. A variable is used to cache the currently
// active lanes.
//
DataType mask_dtype = DataType::UInt(32);
Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
{
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
if (group_extent > 1) {
mask = mask & (((1 << reduce_extent) - 1) << (reduce_extent * group_index));
}
seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
// Push the buffer description. Later this will have an
// allocation built for it.
local_bufs.push_back(mask_buffer);
}
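// For example, with reduce_extent = 8 and group_index = 1 the active mask is
// narrowed to bits 8..15 (0xFF00), i.e. the lanes belonging to that group.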
// Emit reductions within a warp.
int start_offset = 1;
while (start_offset * 2 < reduce_extent) {
start_offset *= 2;
}
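// For example, reduce_extent = 6 yields start_offset = 4, so the shuffle
// offsets visited below are 4, 2, 1.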
for (int offset = start_offset; offset > 0; offset /= 2) {
// Load reduction values, no synchronization needed.
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
Buffer shared_buf = shared_bufs[i];
BufferLoad val(shared_buf, zero_indices);
ICHECK_EQ(val->dtype, types[i]);
a.push_back(val);
// __shfl_*sync calls shall not appear in if_then_else expressions
// as this causes extra divergence. E.g.
//
// v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
//
// behaves differently from
//
// int t = __shfl_sync(mask, v1, 0);
// v1 = (v2 < v3) ? v3 : t;
//
// The former may cause a deadlock as there is a divergent
// branch with a warp sync call inside it.
//
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
Buffer local_buf = local_bufs[i];
Stmt s = BufferStore(local_buf, other, zero_indices);
seq.push_back(s);
BufferLoad load = BufferLoad(local_buf, zero_indices);
ICHECK_EQ(load->dtype, types[i]);
b.push_back(load);
}
// Do reductions.
Array<PrimExpr> ret = (*combiner)(a, b);
// Store the reduction result to itself.
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
Buffer buf = shared_bufs[i];
stores[i] = BufferStore(buf, ret[i], zero_indices);
}
// During the sub-warp reduction, values from inactive threads could be read,
// which is undefined behavior according to the CUDA documentation.
//
// In practice, the returned values are usually 0, which does no harm to a sum reduction.
// However, the result can be incorrect in a max or prod reduction.
// Therefore an additional range check has to be performed to ensure correctness.
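// Continuing the reduce_extent = 6 example: at offset = 4 the check below
// fires (4 * 2 > 6), so only threads with reduce_index + 4 < 6 commit the
// combined value; at offsets 2 and 1 all lanes store unconditionally.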
if (offset * 2 > reduce_extent) {
PrimExpr cond = reduce_index + offset < reduce_extent;
seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
} else {
seq.push_back(SeqStmt::Flatten(stores));
}
}
// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids emitting predicated stores, as all threads are
// uniformly writing the same result.
//
for (size_t i = 0; i < size; ++i) {
Buffer buf = shared_bufs[i];
PrimExpr val = BufferLoad(buf, zero_indices);
ICHECK_EQ(val->dtype, types[i]);
PrimExpr splat =
WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
seq.push_back(BufferStore(buf, splat, zero_indices));
}
// Update existing allocations.
for (size_t i = 0; i < size; ++i) {
ICHECK(!load_remap_.count(buffers[i]->data.get()));
PrimExpr pred = const_true(types[i].lanes());
Buffer buf = shared_bufs[i];
PrimExpr val = BufferLoad(buf, zero_indices);
ICHECK_EQ(val->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = val;
store_remap_[buffers[i].get()] = buf;
Array<PrimExpr> extents{PrimExpr(1)};
auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
var_remap_[buffers[i]->data.get()] = buf->data;
warp_allocs_.insert(node.get());
}
} else {
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores;
for (size_t i = 0; i < size; ++i) {
stores.push_back(BufferStore(buffers[i], values[i], {0}));
}
return SeqStmt::Flatten(stores);
}
// This sync is necessary because there might be an incomplete read of
// the previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
shared_bufs[idx] = buffer;
shared_buffer_vars[idx] = buffer->data;
PrimExpr pred = const_true(types[idx].lanes());
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
reduce_extent, group_extent, contiguous_reduce_extent));
for (size_t idx = 0; idx < size; ++idx) {
ICHECK(!load_remap_.count(buffers[idx]->data.get()));
PrimExpr pred = const_true(types[idx].lanes());
BufferLoad load(shared_bufs[idx],
{BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] =
Allocate(shared_bufs[idx]->data, types[idx],
{PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
store_remap_[buffers[idx].get()] = shared_bufs[idx];
}
}
// Finalize all local allocations now that all statements are built.
Stmt body = SeqStmt::Flatten(seq);
for (Buffer buf : local_bufs) {
body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
new_storage_scopes_[buf->data.get()] = "local";
}
return body;
}
// Make the shared-memory allreduce body.
Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
PrimExpr group_index, int reduce_extent, int group_extent,
int contiguous_reduce_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
reduce_align = reduce_align << 1;
}
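// For example, reduce_extent = 6 rounds reduce_align up to 8.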
ICHECK_GT(reduce_align, 1);
std::vector<Stmt> seq;
size_t size = shared_bufs.size();
PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto fload = [&](int offset) {
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
BufferLoad b_load(shared_bufs[i],
{BufIndex(reduce_index + offset, group_index, reduce_extent)});
ICHECK_EQ(b_load->dtype, types[i]);
b.push_back(b_load);
BufferLoad a_load(shared_bufs[i], {buf_index});
ICHECK_EQ(a_load->dtype, types[i]);
a.push_back(a_load);
}
Array<PrimExpr> ret = (*combiner)(a, b);
return ret;
};
auto fstore = [&](const Array<PrimExpr>& ret) {
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
}
return SeqStmt::Flatten(stores);
};
auto freduce = [&](int offset) {
auto ret = fload(offset);
return fstore(ret);
};
// Step one: if reduce_extent is not a power of two, fold the tail
// elements first, guarded by the boundary condition.
if (reduce_align > reduce_extent) {
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
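// Continuing the reduce_extent = 6 example: reduce_align is now 4 and threads
// with reduce_index < 2 have folded elements 4..5 into elements 0..1, leaving
// a 4-element power-of-two region for the loops below.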
// normal synchronization
bool warp_align = group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
while (reduce_align > contiguous_reduce_extent || reduce_align > warp_size_ || !warp_align) {
if (reduce_align == 1) {
break;
}
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
// in warp synchronization.
if (reduce_align > 1) {
PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
std::vector<Stmt> in_warp_seq;
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
// freduce can read/write to the same memory location. For
// example, with reduce_align of 4, threadIdx 3 reads from
// memory location 7 as threadIdx 7 is writing to it.
// Therefore, we need to separate out the load from the store
// with a memory barrier in-between. This isn't necessary for
// the earlier normal synchronization, because those are each
// protected by an if-statement. The if-statement is avoided
// here to reduce thread divergence.
auto loads = fload(reduce_align);
Array<Var> in_warp_local_vars;
for (auto expr : loads) {
Var var(
"w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()),
expr->dtype);
in_warp_local_vars.push_back(var);
}
std::vector<Stmt> in_let_statement;
in_let_statement.emplace_back(SyncThread("warp"));
in_let_statement.emplace_back(
fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
in_let_statement.emplace_back(SyncThread("warp"));
Stmt body = SeqStmt::Flatten(in_let_statement);
for (size_t i = 0; i < size; i++) {
body = LetStmt(in_warp_local_vars[i], loads[i], body);
}
in_warp_seq.push_back(body);
}
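// Each iteration above yields a statement of the form
//   let w_<align>_i = combine(buf[idx], buf[idx + align]) ... in
//     { sync(warp); buf[idx] = w_<align>_i; sync(warp); }
// so all reads complete before any write of the same step.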
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
seq.emplace_back(SyncThread("shared"));
}
return SeqStmt::Flatten(seq);
}
// Flatten the thread indices into a single linear index.
// Also return the total extent of the flattened threads via out_total_extent.
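// For example, flattening threadIdx.x (extent 4) and threadIdx.y (extent 2),
// sorted by dim_index, yields x + y * 4 with a total extent of 8.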
PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec, int* out_total_extent) {
int& total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
return make_zero(DataType::Int(32));
}
PrimExpr ret;
for (const ThreadEntry& e : tvec) {
if (ret.defined()) {
ret = ret + e.iv->var * total_extent;
} else {
ICHECK_EQ(total_extent, 1);
ret = e.iv->var;
}
total_extent *= e.extent;
}
return ret;
}
// The local buffer index.
PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
if (!is_zero(group_index)) {
return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
} else {
return reduce_index;
}
}
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}));
}
// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
Array<PrimExpr> indices = {0};
PrimExpr mask = BufferLoad(mask_buffer, indices);
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
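// Both the shuffle width and the warp-size operand are set to the target's warp size.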
return Call(val.dtype(), op, args);
}
// Check if we can use warp-level reduction.
//
// Note: the ROCm backend only supports warp reductions for now.
// Also, the warp/wavefront size differs (64 on ROCm, 32 on CUDA).
bool is_warp_reduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
int contiguous_reduce_extent) const {
// Only the CUDA and ROCm targets support warp reductions.
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
// ROCm only supports 32-bit operands for shuffling at the moment.
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_vector()) return true;
return ty.bits() != 32;
}))) {
return false;
}
// Supported types:
// {u}int, {u}long, {u}long long, float, double, half/half2
if (std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_float16()) return ty.lanes() > 2;
if (ty.is_vector()) return true;
return ty.bytes() < 4 || ty.bytes() > 8;
})) {
return false;
}
if (thread_extents_.empty()) {
return false;
}
// reduce region must be contiguous.
if (contiguous_reduce_extent != reduce_extent) {
return false;
}
// Check whether reduce_extent and group_extent are valid for a warp reduction.
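// For example, on CUDA with warp_size_ = 32 (and the checks above passed):
// reduce_extent = 16 works for any group_extent, while reduce_extent = 24
// requires group_extent == 1.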
if (target_->kind->name == "rocm") {
return reduce_extent == warp_size_;
} else { // target_->kind->name == "cuda"
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
if (warp_size_ % reduce_extent == 0) {
return true; // warp size is multiple of reduce extent
} else {
return group_extent == 1 && reduce_extent <= warp_size_;
}
}
}
}
// The target.
const TargetNode* target_ = nullptr;
// The warp size of the device.
int warp_size_{1};
// surrounding scope of thread extent.
std::vector<const AttrStmtNode*> thread_extents_;
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// The store remap
std::unordered_map<const BufferNode*, Buffer> store_remap_;
// Allocate remap
std::unordered_map<const VarNode*, Stmt> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode*, Var> var_remap_;
// Buffer remap
std::unordered_map<const BufferNode*, Buffer> buf_remap_;
// Allocate from warp reductions
std::unordered_set<const void*> warp_allocs_;
// Internal analyzer
arith::Analyzer analyzer_;
};
namespace transform {
Pass LowerThreadAllreduce() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
const TargetNode* target_node = target.as<TargetNode>();
ThreadAllreduceBuilder thread_all_reduce(target_node);
auto reduce_body = thread_all_reduce(n->body);
n->body =
UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
}
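// Usage sketch (hedged): from Python this pass is exposed as
// tvm.tir.transform.LowerThreadAllreduce(); the PrimFunc must already carry
// the tvm::attr::kTarget attribute (e.g. attached by a target-binding pass),
// otherwise the ICHECK above fires.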
TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce);
} // namespace transform
} // namespace tir
} // namespace tvm