/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file schedule_postproc_rewrite_for_tensor_core.cc
*
* \brief Rewrite the Stmt generated by ScheduleOps
 * to accommodate TensorCore.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/target/target_info.h>
#include <tvm/te/operation.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace te {
using namespace te;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
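// A warp or thread tile size (m, n, k); -1 marks a dimension that has not been
// assigned yet (see BufferAnalyser::assign_or_check_).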
struct Tile {
int m{-1};
int n{-1};
int k{-1};
};
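// Strip everything after the first '.' from a buffer name, e.g. "A.shared" -> "A".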
std::string simplify_name(std::string input) {
auto pos = input.find(".");
if (pos != std::string::npos) {
return input.substr(0, pos);
} else {
return input;
}
}
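// If `input` is a cast to `target_type`, return the value being cast; if it is
// not a cast at all, return `input` unchanged; otherwise return an undefined
// PrimExpr so callers can treat a mismatched cast as "no match".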
PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) {
auto cast = input.as<CastNode>();
if (cast == nullptr) {
return input;
} else if (cast->dtype == target_type) {
return cast->value;
}
return PrimExpr();
}
// MMAMatcher matches C = Cast(A)*Cast(B) + C,
// where A and B are fp16 or narrow-integer (int8/uint8/int4/uint4/int1) local buffers,
// and C is a fp32/int32 local buffer.
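// For illustration, a matched statement looks roughly like the following
// (hypothetical buffer names):
//   C.local[i, j] = C.local[i, j] + float32(A.local[i, k]) * float32(B.local[k, j])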
class MMAMatcher : public StmtVisitor {
public:
explicit MMAMatcher(Map<Tensor, Buffer> extern_buffer) {
for (auto kv : extern_buffer) {
BufferInfo bi;
bi.name = kv.second->name;
bi.dtype = kv.second->dtype;
bi.external = true;
buf_map_[kv.first] = bi;
}
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::pragma_tensor_core) {
tensor_core_on_ = true;
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else {
StmtVisitor::VisitStmt_(op);
}
}
void VisitStmt_(const ProducerStoreNode* op) final {
StmtVisitor::VisitStmt_(op);
auto it = buf_map_.find(Downcast<Tensor>(op->producer));
if (it == buf_map_.end()) {
return;
}
const BufferInfo& bi = it->second;
if (bi.released) {
return;
}
if (tensor_core_on_ && mma_sync_match_(op, bi)) {
matched_ = true;
}
}
void VisitStmt_(const ProducerRealizeNode* op) final {
auto key = Downcast<Tensor>(op->producer);
if (buf_map_.count(key)) {
if (!buf_map_.at(key).external) {
return;
}
this->VisitStmt(op->body);
} else {
BufferInfo bi;
bi.name = key->GetNameHint();
bi.dtype = key->dtype;
buf_map_[key] = bi;
this->VisitStmt(op->body);
buf_map_[key].released = true;
}
}
inline bool Matched() const { return matched_; }
friend class ScheduleAnalyser;
friend class BufferAnalyser;
private:
struct BufferInfo {
std::string name;
DataType dtype;
bool external{false};
bool released{false};
bool same_as(const BufferInfo& bi) {
if (this->dtype != bi.dtype) return false;
if (this->name != bi.name) return false;
if (this->external != bi.external) return false;
if (this->released != bi.released) return false;
return true;
}
};
// Check whether the storage scope is local
  bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) {
    // The operand may not be a ProducerLoad at all (e.g. a constant or another
    // expression), in which case the pattern simply does not match.
    if (op == nullptr) {
      return false;
    }
auto tensor = Downcast<Tensor>(op->producer);
auto it = storage_scope_.find(tensor.get());
if (it == storage_scope_.end()) {
return false;
}
const std::string& strkey = it->second;
if (strkey != "local") {
return false;
}
auto it1 = buf_map_.find(tensor);
if (it1 == buf_map_.end()) {
return false;
}
*bi = it1->second;
if (bi->released) {
return false;
}
return true;
}
// Do the pattern matching
bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) {
auto* add = op->value.as<AddNode>();
if (add == nullptr) {
return false;
}
auto* load_c = add->a.as<ProducerLoadNode>();
BufferInfo buffer_c;
if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) ||
!(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) {
return false;
}
auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<MulNode>();
if (mul == nullptr) {
return false;
}
auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
auto load_a = load_a_expr.as<ProducerLoadNode>();
BufferInfo buffer_a;
if (!check_local_buffer_(load_a, &buffer_a) ||
!(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) ||
buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) ||
buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) {
return false;
}
auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
auto load_b = load_b_expr.as<ProducerLoadNode>();
BufferInfo buffer_b;
if (!check_local_buffer_(load_b, &buffer_b) ||
!(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) ||
buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) ||
          buffer_b.dtype == DataType::UInt(4) || buffer_b.dtype == DataType::Int(1))) {
return false;
}
frag_reg_.insert(buffer_c.name);
frag_reg_.insert(buffer_a.name);
frag_reg_.insert(buffer_b.name);
buf_name_.insert(std::make_pair(load_a, buffer_a.name));
buf_name_.insert(std::make_pair(load_b, buffer_b.name));
mma_sync_.insert(std::make_pair(op, Array<PrimExpr>{load_a_expr, load_b_expr, add->a}));
return true;
}
std::unordered_map<Tensor, BufferInfo> buf_map_;
std::unordered_map<const Object*, std::string> storage_scope_;
std::unordered_map<const ProducerStoreNode*, Array<PrimExpr>> mma_sync_;
std::unordered_map<const Object*, std::string> buf_name_;
std::unordered_set<std::string> frag_reg_;
bool matched_{false};
bool tensor_core_on_{false};
};
// BodyVisitor visits the body stmt of the original ComputeOp
// to get the access indices of the input matrices,
// if the body is recognized as a matrix multiply.
class BodyVisitor : public StmtExprVisitor {
public:
BodyVisitor() {}
void VisitExpr_(const ReduceNode* op) final {
auto* comm_add = op->combiner->result[0].as<AddNode>();
if (comm_add == nullptr || op->combiner->result.size() > 1) {
return;
}
for (PrimExpr source : op->source) {
auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<MulNode>();
auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<MulNode>();
if (mul_0 == nullptr && mul_1 == nullptr) {
continue;
}
tensorcore_candidate_ = true;
StmtExprVisitor::VisitExpr(source);
}
}
void VisitExpr_(const ProducerLoadNode* op) final {
StmtExprVisitor::VisitExpr_(op);
args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices));
}
friend class ScheduleAnalyser;
private:
std::unordered_map<std::string, Array<PrimExpr>> args_;
bool tensorcore_candidate_{false};
};
// ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major
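// With i, j the last two output axes and k the single reduction axis, the last
// two access indices of each input buffer determine its role and layout:
//   (k, j) -> matrix_a / col_major      (k, i) -> matrix_b / row_major
//   (j, k) -> matrix_a / row_major      (i, k) -> matrix_b / col_major
// The output itself is tagged as the accumulator.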
class ScheduleAnalyser {
public:
explicit ScheduleAnalyser(const MMAMatcher& mma_matcher)
: mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {}
bool MatrixIdentify(Schedule schedule) {
// TODO(minmin): handle the case where MatMul is not the output stage
for (Operation output : schedule->outputs) {
const ComputeOpNode* compute = output.as<ComputeOpNode>();
if (compute == nullptr) {
// Not a ComputeOp
continue;
}
auto axis = compute->axis;
auto reduce_axis = compute->reduce_axis;
if (axis.size() < 2 || reduce_axis.size() != 1) {
continue;
}
const VarNode* axis_var[2];
const VarNode* reduce_axis_var;
axis_var[0] = axis[axis.size() - 2]->var.as<VarNode>();
axis_var[1] = axis[axis.size() - 1]->var.as<VarNode>();
reduce_axis_var = reduce_axis[0]->var.as<VarNode>();
BodyVisitor body_visitor;
for (PrimExpr expr : compute->body) {
body_visitor(expr);
}
if (!body_visitor.tensorcore_candidate_) {
continue;
}
for (auto iter : body_visitor.args_) {
auto name = iter.first;
auto args = iter.second;
if (args.size() < 2) {
continue;
}
const VarNode* var0 = args[args.size() - 2].as<VarNode>();
const VarNode* var1 = args[args.size() - 1].as<VarNode>();
if (var0 == nullptr || var1 == nullptr) {
continue;
}
std::string matrix_abc, major;
if (var0 == reduce_axis_var && var1 == axis_var[1]) {
matrix_abc = "matrix_a";
major = "col_major";
} else if (var0 == reduce_axis_var && var1 == axis_var[0]) {
matrix_abc = "matrix_b";
major = "row_major";
} else if (var0 == axis_var[1] && var1 == reduce_axis_var) {
matrix_abc = "matrix_a";
major = "row_major";
} else if (var0 == axis_var[0] && var1 == reduce_axis_var) {
matrix_abc = "matrix_b";
major = "col_major";
}
matrix_abc_.insert(std::make_pair(name, matrix_abc));
matrix_major_.insert(std::make_pair(name, major));
}
matrix_abc_.insert(std::make_pair(compute->name, "accumulator"));
matrix_major_.insert(std::make_pair(compute->name, "col_major"));
}
for (auto& mma_sync : mma_sync_) {
auto& operands = mma_sync.second;
      auto* load_a = operands[0].as<ProducerLoadNode>();
      auto* load_b = operands[1].as<ProducerLoadNode>();
auto input0 = simplify_name(buf_name_.find(load_a)->second);
auto input1 = simplify_name(buf_name_.find(load_b)->second);
auto it0 = matrix_abc_.find(input0);
auto it1 = matrix_abc_.find(input1);
if (it0 == matrix_abc_.end() || it1 == matrix_abc_.end()) {
return false;
}
if (it0->second == "matrix_a" && it1->second == "matrix_b") {
return true;
} else if (it0->second == "matrix_b" && it1->second == "matrix_a") {
mma_sync.second = Array<PrimExpr>{operands[1], operands[0], operands[2]};
} else {
return false;
}
}
return true;
}
friend class BufferAnalyser;
friend class TensorCoreIRMutator;
private:
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_map<const ProducerStoreNode*, Array<PrimExpr>> mma_sync_;
std::unordered_map<const Object*, std::string> buf_name_;
};
// IndexVisitor visits the access indices of fragments
// to record the loop variables that need scaling.
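// The extent of a loop whose variable indexes into a fragment is later divided
// by the fragment extent recorded here (scaling_factor_, typically 16); see
// TensorCoreIRMutator::VisitStmt_(const ForNode*).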
class IndexVisitor : public StmtExprVisitor {
public:
IndexVisitor() {}
void VisitExpr_(const VarNode* op) final {
loop_scaling_.insert(std::make_pair(op, scaling_factor_));
}
friend class BufferAnalyser;
friend class TensorCoreIRMutator;
private:
std::unordered_map<const VarNode*, unsigned> loop_scaling_;
unsigned scaling_factor_{0};
};
// BufferAnalyser gets buffer info,
// e.g. thread tile and warp tile, for TensorCore CodeGen
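// It records buffer strides and fragment load/store sites, checks that the last
// two dimensions of every fragment are multiples of 16, and accumulates the
// per-thread tile that QualifiedForTensorCore later expands into a warp tile.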
class BufferAnalyser : public StmtExprVisitor {
public:
explicit BufferAnalyser(Map<Tensor, Buffer> extern_buffer,
const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher)
: matrix_abc_(schedule_analyser.matrix_abc_),
matrix_major_(schedule_analyser.matrix_major_),
frag_reg_(mma_matcher.frag_reg_) {
for (auto kv : extern_buffer) {
BufferInfo bi;
bi.name = kv.second->name;
bi.dtype = kv.second->dtype;
bi.strides = kv.second->strides;
bi.shape = kv.second->shape;
bi.external = true;
buf_map_[kv.first] = bi;
}
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
if (const IntImmNode* value = op->value.as<IntImmNode>()) {
thread_extent_.insert(
std::make_pair(op->node.as<IterVarNode>()->var->name_hint, value->value));
}
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else if (op->attr_key == tir::attr::buffer_dim_align) {
te::Tensor tensor = Downcast<te::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
auto& vinfo = dim_align_[tensor];
size_t dim = tuple->args[0].as<IntImmNode>()->value;
if (dim >= vinfo.size()) {
vinfo.resize(dim + 1);
}
vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
this->VisitStmt(op->body);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void VisitStmt_(const ProducerStoreNode* op) final {
StmtExprVisitor::VisitStmt_(op);
auto key = Downcast<Tensor>(op->producer);
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint();
const BufferInfo& bi = it->second;
CHECK(!bi.released) << "Read a buffer that is already out of scope";
if (matrix_abc_.count(key->GetNameHint())) {
if (bi.shape.size() < 2) {
invalid_ = true;
return;
}
for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
if (shape == nullptr || shape->value % 16 != 0) {
invalid_ = true;
return;
}
}
}
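    // Without explicit strides, compute compact row-major strides from the shape,
    // e.g. shape [m, n, k] gives strides [n*k, k, 1].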
Array<PrimExpr> strides;
if (bi.strides.size() > 0) {
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
PrimExpr stride = IntImm(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = Mul(stride, bi.shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(DataType::Int(32), 1));
}
strides_.insert(std::make_pair(key->GetNameHint(), strides));
if (frag_reg_.count(bi.name)) {
PrimExpr dst = ProducerLoad(op->producer, op->indices);
frag_load_.insert(std::make_pair(op, dst));
auto rel_index = bi.RelIndex(op->indices);
if (op->indices.size() < 2) {
invalid_ = true;
return;
}
std::vector<int> tile_size;
for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) {
index_visitor.scaling_factor_ = 16;
if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
tile_size.push_back(shape->value);
index_visitor.scaling_factor_ = shape->value;
} else {
invalid_ = true;
return;
}
auto index = rel_index[i];
auto simplified_index = analyzer_.Simplify(index);
index_visitor(simplified_index);
}
std::string input_name = simplify_name(bi.name);
auto it = matrix_abc_.find(input_name);
auto it2 = matrix_major_.find(input_name);
bool ret = true;
if (it != matrix_abc_.end() && it2 != matrix_major_.end()) {
if (it->second == "matrix_a" && it2->second == "col_major") {
ret &= assign_or_check_(&thread_tile_.m, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.k, tile_size[1]);
}
if (it->second == "matrix_a" && it2->second == "row_major") {
ret &= assign_or_check_(&thread_tile_.k, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.m, tile_size[1]);
}
if (it->second == "matrix_b" && it2->second == "col_major") {
ret &= assign_or_check_(&thread_tile_.k, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.n, tile_size[1]);
}
if (it->second == "matrix_b" && it2->second == "row_major") {
ret &= assign_or_check_(&thread_tile_.n, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.k, tile_size[1]);
}
if (it->second == "accumulator") {
ret &= assign_or_check_(&thread_tile_.m, tile_size[0]);
ret &= assign_or_check_(&thread_tile_.n, tile_size[1]);
}
if (!ret) {
invalid_ = true;
return;
}
}
}
const ProducerLoadNode* value = op->value.as<ProducerLoadNode>();
// TODO(tvm-team): string matching is dangerous, consider other means.
if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) {
PrimExpr dst = ProducerLoad(op->producer, op->indices);
frag_store_.insert(std::make_pair(op, dst));
}
}
void VisitExpr_(const ProducerLoadNode* op) final {
StmtExprVisitor::VisitExpr_(op);
auto tensor = Downcast<Tensor>(op->producer);
auto it = buf_map_.find(tensor);
CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint();
const BufferInfo& bi = it->second;
CHECK(!bi.released) << "Read a buffer that is already out of scope";
if (matrix_abc_.count(tensor->op->name)) {
if (bi.shape.size() < 2) {
invalid_ = true;
return;
}
for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
if (shape == nullptr || shape->value % 16 != 0) {
invalid_ = true;
return;
}
}
}
Array<PrimExpr> strides;
if (bi.strides.size() > 0) {
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
PrimExpr stride = IntImm(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = Mul(stride, bi.shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(DataType::Int(32), 1));
}
strides_.insert(std::make_pair(tensor->GetNameHint(), strides));
if (!frag_reg_.count(bi.name)) {
return;
}
auto rel_index = bi.RelIndex(op->indices);
if (op->indices.size() < 2) {
invalid_ = true;
return;
}
for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) {
index_visitor.scaling_factor_ = 16;
if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
index_visitor.scaling_factor_ = shape->value;
}
auto index = rel_index[i];
auto simplified_index = analyzer_.Simplify(index);
index_visitor(simplified_index);
}
}
void VisitStmt_(const ProducerRealizeNode* op) final {
auto key = Downcast<Tensor>(op->producer);
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
this->VisitStmt(op->body);
} else {
// create a buffer entry
BufferInfo bi;
bi.bounds = op->bounds;
Array<PrimExpr> shape;
for (auto r : bi.bounds) {
shape.push_back(r->extent);
}
Array<PrimExpr> strides;
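      // With a buffer_dim_align pragma, pad each stride up to the next value
      // congruent to align_offset modulo align_factor before multiplying in the
      // dimension's extent.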
if (dim_align_.count(key) != 0 && shape.size() != 0) {
std::vector<PrimExpr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key];
int first_dim = 0;
PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) {
PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = analyzer_.Simplify(stride);
}
rstrides.push_back(stride);
stride = stride * shape[dim];
}
strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}
bi.name = key->GetNameHint();
bi.dtype = key->dtype;
bi.strides = strides;
bi.shape = shape;
buf_map_[key] = bi;
this->VisitStmt(op->body);
buf_map_[key].released = true;
}
}
// Derive warp tile from thread tile,
  // and check whether it qualifies for TensorCore.
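  // For example, assuming threadIdx.x = 2, threadIdx.y a multiple of 16, and a
  // per-thread tile of 8x1x16:
  //   warp_tile.m = 2 * 8 = 16, warp_threads_y = 32 / 2 = 16,
  //   warp_tile.n = 16 * 1 = 16, warp_tile.k = 16,
  // which matches the 16x16x16 shape accepted by supported_warp_tile_().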
bool QualifiedForTensorCore() {
if (invalid_) {
return false;
}
auto itx = thread_extent_.find("threadIdx.x");
if (itx == thread_extent_.end()) {
return false;
}
int warp_threads_x = itx->second;
warp_tile_.m = warp_threads_x * thread_tile_.m;
warp_threads_y_ = 32 / warp_threads_x;
auto ity = thread_extent_.find("threadIdx.y");
if (ity == thread_extent_.end()) {
return false;
}
if (ity->second < warp_threads_y_ || ity->second % warp_threads_y_ != 0) {
return false;
}
warp_tile_.n = warp_threads_y_ * thread_tile_.n;
warp_tile_.k = thread_tile_.k;
return supported_warp_tile_();
}
friend class TensorCoreIRMutator;
private:
struct DimAlignInfo {
int align_factor{0};
int align_offset{0};
};
struct BufferInfo {
std::string name;
DataType dtype;
Array<PrimExpr> strides;
Array<PrimExpr> shape;
Region bounds;
bool external{false};
bool released{false};
inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
if (bounds.size() != 0) {
Array<PrimExpr> index;
CHECK_EQ(bounds.size(), args.size());
for (size_t i = 0; i < bounds.size(); ++i) {
index.push_back(args[i] - bounds[i]->min);
}
return index;
} else {
return args;
}
}
};
bool assign_or_check_(int* dst, int src) {
if (*dst <= 0) {
*dst = src;
return true;
}
if (*dst == src) {
return true;
}
return false;
}
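  // The accepted warp tiles correspond to the WMMA fragment shapes CUDA provides:
  // m16n16k16, m32n8k16 and m8n32k16 for fp16/int8, m8n8k32 for 4-bit integers,
  // and m8n8k128 for 1-bit (bmma) operands.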
bool supported_warp_tile_() {
if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) {
return true;
}
if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) {
return true;
}
return false;
}
std::unordered_map<Tensor, BufferInfo> buf_map_;
std::unordered_map<Tensor, std::vector<DimAlignInfo>> dim_align_;
std::unordered_map<const Object*, std::string> storage_scope_;
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_set<std::string> frag_reg_;
std::unordered_map<std::string, Array<PrimExpr>> strides_;
std::unordered_map<const ProducerStoreNode*, PrimExpr> frag_load_;
std::unordered_map<const ProducerStoreNode*, PrimExpr> frag_store_;
std::unordered_map<std::string, int> thread_extent_;
IndexVisitor index_visitor;
Tile warp_tile_;
Tile thread_tile_;
arith::Analyzer analyzer_;
int warp_threads_y_{-1};
bool invalid_{false};
};
// ThreadIdxMutator does the thread index unification inside a warp
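// The wmma load/store intrinsics operate on a whole warp, so every lane must
// compute the same fragment address: threadIdx.x is replaced with 0 and
// threadIdx.y is rounded down to a multiple of warp_y (the y-extent covered by
// one warp).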
class ThreadIdxMutator : public StmtExprMutator {
public:
explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {}
PrimExpr VisitExpr_(const VarNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<VarNode>();
if (op != nullptr) {
if (op->name_hint == "threadIdx.x") {
PrimExpr zero = IntImm(DataType::Int(32), 0);
return zero;
}
if (op->name_hint == "threadIdx.y") {
PrimExpr div = Div(expr, warp_y_);
PrimExpr mul = Mul(div, warp_y_);
return mul;
}
}
return expr;
}
private:
PrimExpr warp_y_;
};
// TensorCoreIRMutator mutates the AST for TensorCore CodeGen
// based on tensor core intrinsics
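// It shrinks fragment buffers to their tile shape, tags their realize scope as
// wmma.matrix_a / wmma.matrix_b / wmma.accumulator, replaces the matched
// multiply-accumulate and the fragment loads/stores with tvm_mma_sync,
// tvm_fill_fragment, tvm_load_matrix_sync and tvm_store_matrix_sync intrinsics,
// and rescales the loops that iterate inside a fragment.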
class TensorCoreIRMutator : public StmtExprMutator {
public:
explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser,
const BufferAnalyser& buffer_analyser)
: matrix_abc_(schedule_analyser.matrix_abc_),
matrix_major_(schedule_analyser.matrix_major_),
mma_sync_(schedule_analyser.mma_sync_),
strides_(buffer_analyser.strides_),
frag_reg_(buffer_analyser.frag_reg_),
loop_scaling_(buffer_analyser.index_visitor.loop_scaling_),
frag_load_(buffer_analyser.frag_load_),
frag_store_(buffer_analyser.frag_store_),
warp_tile_(buffer_analyser.warp_tile_),
warp_threads_y_(buffer_analyser.warp_threads_y_) {}
Stmt VisitStmt_(const ProducerRealizeNode* op) final {
auto key = Downcast<Tensor>(op->producer);
bounds_[key] = op->bounds;
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ProducerRealizeNode>();
if (op != nullptr) {
if (!frag_reg_.count(key->GetNameHint())) {
return stmt;
}
auto new_extents = get_tile_size_(simplify_name(key->GetNameHint()));
Region new_bounds;
for (size_t i = 0; i < op->bounds.size() - 2; ++i) {
new_bounds.push_back(op->bounds[i]);
}
CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint();
new_bounds.push_back(
Range::FromMinExtent(op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
new_bounds.push_back(
Range::FromMinExtent(op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
return ProducerRealize(op->producer, new_bounds, op->condition, op->body);
}
return stmt;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
if (op->attr_key == tir::attr::realize_scope) {
auto node = op->node.as<te::OperationNode>();
if (node != nullptr) {
if (!frag_reg_.count(node->name)) {
return stmt;
}
auto it = matrix_abc_.find(simplify_name(node->name));
CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name;
auto matrix_abc = tvm::tir::StringImm("wmma." + it->second);
Stmt body = this->VisitStmt(op->body);
return AttrStmt(op->node, op->attr_key, matrix_abc, body);
}
}
return stmt;
}
Stmt VisitStmt_(const ProducerStoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = mma_sync_.find(op);
if (it != mma_sync_.end()) {
const auto& operands = it->second;
PrimExpr a = operands[0];
auto ca = a.as<ProducerLoadNode>();
PrimExpr b = operands[1];
auto cb = b.as<ProducerLoadNode>();
PrimExpr c = operands[2];
auto cc = c.as<ProducerLoadNode>();
ObjectPtr<BufferNode> buffer_node_a = make_object<BufferNode>();
ObjectPtr<BufferNode> buffer_node_b = make_object<BufferNode>();
ObjectPtr<BufferNode> buffer_node_c = make_object<BufferNode>();
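      // Nest three buffer_bind_scope regions (A, then B, then C); the innermost
      // callback emits tvm_mma_sync(C, A, B, C), or tvm_bmma_sync when both
      // operands are 1-bit.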
auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) {
Buffer buffer_a(buffer_node_a);
Buffer buffer_b(buffer_node_b);
if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
return Evaluate(
Call(DataType::Handle(), builtin::tvm_bmma_sync(),
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}));
} else {
return Evaluate(
Call(DataType::Handle(), builtin::tvm_mma_sync(),
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}));
}
};
auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) {
return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call);
};
auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) {
return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c);
};
return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b);
}
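    // A store into a fragment: a scalar constant becomes tvm_fill_fragment,
    // while a copy from an ordinary buffer becomes tvm_load_matrix_sync with the
    // source address taken via call_extern("&", ...).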
auto it2 = frag_load_.find(op);
if (it2 != frag_load_.end()) {
PrimExpr dst = it2->second;
if (op->value.as<FloatImmNode>() != nullptr || op->value.as<IntImmNode>() != nullptr) {
auto pload = dst.as<ProducerLoadNode>();
auto fill_fragment_call = [this, &op](const Buffer& buffer) {
return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, op->value}));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call);
}
const ProducerLoadNode* value = op->value.as<ProducerLoadNode>();
CHECK(value != nullptr) << "Can only load fragment from a buffer";
auto it = strides_.find(value->producer->GetNameHint());
CHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint();
auto strides = it->second;
CHECK_GE(strides.size(), 2);
PrimExpr stride = strides[strides.size() - 2];
// thread index unification inside a warp
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
PrimExpr mutated_value = thread_idx_mutator(op->value);
// TODO(tvm-team) The extern function name seems to be a hack.
PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value});
auto pload = dst.as<ProducerLoadNode>();
PrimExpr matrix_major;
auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint()));
CHECK(iter2 != matrix_major_.end())
<< "Can not determine matrix major for " << pload->producer->GetNameHint();
if (iter2->second == "col_major") {
matrix_major = StringImm("col_major");
} else if (iter2->second == "row_major") {
matrix_major = StringImm("row_major");
} else {
LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint();
}
auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, src, stride, matrix_major}));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call);
}
auto it3 = frag_store_.find(op);
if (it3 != frag_store_.end()) {
auto it = strides_.find(op->producer->GetNameHint());
CHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint();
auto strides = it->second;
CHECK_GE(strides.size(), 2);
PrimExpr stride = strides[strides.size() - 2];
PrimExpr dst = it3->second;
// thread index unification inside a warp
PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
dst = Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst});
auto pload = op->value.as<ProducerLoadNode>();
auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(),
{buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, dst, stride, StringImm("col_major")}));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call);
}
return stmt;
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op != nullptr) {
auto it = loop_scaling_.find(op->loop_var.get());
if (it != loop_scaling_.end()) {
int scale_factor = it->second;
int scaled_extent_value = 1;
if (const IntImmNode* ori_extent = op->extent.as<IntImmNode>()) {
int ori_extent_value = ori_extent->value;
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body);
}
}
return stmt;
}
private:
Array<PrimExpr> get_tile_size_(const std::string& name) {
auto it = matrix_abc_.find(name);
auto it2 = matrix_major_.find(name);
CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
<< "Cannot find matrix info for " << name;
PrimExpr size0 = make_const(DataType::Int(32), 16);
PrimExpr size1 = make_const(DataType::Int(32), 16);
if (it->second == "matrix_a" && it2->second == "col_major") {
size0 = make_const(DataType::Int(32), warp_tile_.k);
size1 = make_const(DataType::Int(32), warp_tile_.m);
}
if (it->second == "matrix_a" && it2->second == "row_major") {
size0 = make_const(DataType::Int(32), warp_tile_.m);
size1 = make_const(DataType::Int(32), warp_tile_.k);
}
if (it->second == "matrix_b" && it2->second == "row_major") {
size0 = make_const(DataType::Int(32), warp_tile_.k);
size1 = make_const(DataType::Int(32), warp_tile_.n);
}
if (it->second == "matrix_b" && it2->second == "col_major") {
size0 = make_const(DataType::Int(32), warp_tile_.n);
size1 = make_const(DataType::Int(32), warp_tile_.k);
}
if (it->second == "matrix_c") {
size0 = make_const(DataType::Int(32), warp_tile_.n);
size1 = make_const(DataType::Int(32), warp_tile_.m);
}
Array<PrimExpr> tile_size = {size0, size1};
return tile_size;
}
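  // Bind a fresh wmma-scoped Buffer to the fragment region of `tensor` addressed
  // by `pload` (its last two dimensions set to the warp-tile extents from
  // get_tile_size_, with compact strides and an elem_offset derived from the
  // access indices), wrap call_back(buffer) in the corresponding
  // buffer_bind_scope AttrStmt, and return the result.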
Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload,
const ObjectPtr<BufferNode>& buffer_node,
const std::function<Stmt(const Buffer& buffer)>& call_back) {
auto tensor = Downcast<Tensor>(pload->producer);
auto it = bounds_.find(tensor);
CHECK(it != bounds_.end());
Array<PrimExpr> min_bound;
for (auto i : it->second) {
min_bound.push_back(i->min);
}
CHECK_GE(it->second.size(), 2);
Array<PrimExpr> shape;
for (size_t i = 0; i < it->second.size() - 2; ++i) {
shape.push_back(it->second[i]->extent);
}
auto tile_size = get_tile_size_(simplify_name(tensor->op->name));
shape.push_back(tile_size[0]);
shape.push_back(tile_size[1]);
Array<PrimExpr> strides;
for (size_t i = 1; i < shape.size(); ++i) {
PrimExpr stride = IntImm(DataType::Int(32), 1);
for (size_t j = shape.size() - 1; j >= i; --j) {
stride = Mul(stride, shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(DataType::Int(32), 1));
PrimExpr elem_offset = IntImm(DataType::Int(32), 0);
CHECK_EQ(pload->indices.size(), min_bound.size());
for (size_t i = 0; i < min_bound.size(); i++) {
elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i])));
}
auto it2 = matrix_abc_.find(simplify_name(tensor->op->name));
CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name;
buffer_node->data = Var(tensor->op->name, DataType::Handle());
buffer_node->name = tensor->op->name;
buffer_node->scope = "wmma." + it2->second;
buffer_node->dtype = tensor->dtype;
buffer_node->strides = strides;
buffer_node->shape = shape;
buffer_node->data_alignment = 1;
buffer_node->elem_offset = analyzer_.Simplify(elem_offset);
buffer_node->offset_factor = 1;
Buffer buffer(buffer_node);
Array<PrimExpr> args;
for (size_t i = 0; i < pload->indices.size(); ++i) {
args.push_back(pload->indices[i]);
args.push_back(shape[i]);
}
auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args);
Array<ObjectRef> node = {buffer, tensor};
return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer));
}
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_map<const ProducerStoreNode*, Array<PrimExpr>> mma_sync_;
std::unordered_map<std::string, Array<PrimExpr>> strides_;
std::unordered_set<std::string> frag_reg_;
std::unordered_map<const VarNode*, unsigned> loop_scaling_;
std::unordered_map<const ProducerStoreNode*, PrimExpr> frag_load_;
std::unordered_map<const ProducerStoreNode*, PrimExpr> frag_store_;
std::unordered_map<Tensor, Region> bounds_;
arith::Analyzer analyzer_;
Tile warp_tile_;
int warp_threads_y_{-1};
};
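// Entry point of the pass: run MMAMatcher, ScheduleAnalyser and BufferAnalyser
// in turn, and rewrite the Stmt with TensorCoreIRMutator only if every stage
// succeeds; otherwise return the input Stmt unchanged. The rewrite is skipped
// entirely when the target is not CUDA or no CUDA device is available.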
Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule,
Map<Tensor, Buffer> extern_buffer) {
  // Check whether the current lowering target is CUDA
auto target = tvm::Target::Current(true);
if (target.defined() && target->kind->name != "cuda") {
return stmt;
}
  // Check whether the current runtime supports CUDA GPUs
TVMContext ctx{kDLGPU, 0};
auto api = tvm::runtime::DeviceAPI::Get(ctx, true);
if (api == nullptr) {
return stmt;
}
MMAMatcher mma_matcher(extern_buffer);
mma_matcher(stmt);
if (!mma_matcher.Matched()) {
return stmt;
}
ScheduleAnalyser schedule_analyser(mma_matcher);
if (!schedule_analyser.MatrixIdentify(schedule)) {
return stmt;
}
BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher);
buffer_analyser(stmt);
if (!buffer_analyser.QualifiedForTensorCore()) {
return stmt;
}
return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
}
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
.set_body_typed([](Stmt stmt, Schedule schedule, Map<te::Tensor, Buffer> extern_buffer) {
return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer);
});
} // namespace te
} // namespace tvm