blob: 1b3b3c44ff9ca2d8e7135dfc918fbc0818f2dd34 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Infer TensorCore metadata from tensor intrinsic.
* \file tensorcore_fragment.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
#include "ir_util.h"
#include "storage_access.h"
namespace tvm {
namespace tir {
// Get fragment information from tensor intrinsics
class FragmentGetter : public StmtExprVisitor {
public:
// fragment metadata
struct FragmentInfo {
// fragment shape
int m, n, k;
// fragment layout (row-major or column-major)
std::string layout;
FragmentInfo() = default;
FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
: m(_m), n(_n), k(_k), layout(_layout) {}
};
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
if (op->op.same_as(builtin::tvm_load_matrix_sync()) ||
op->op.same_as(builtin::tvm_store_matrix_sync())) {
// Get shape and layout information from load and store intrinsic
CHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_var = op->args[0].as<VarNode>();
CHECK(buffer_var);
// Get shape
const IntImmNode* m = op->args[1].as<IntImmNode>();
const IntImmNode* n = op->args[2].as<IntImmNode>();
const IntImmNode* k = op->args[3].as<IntImmNode>();
const StringImmNode* layout = op->args[7].as<StringImmNode>();
CHECK(m);
CHECK(n);
CHECK(k);
CHECK(layout);
std::string scope = scopes[buffer_var];
if (fragments.count(buffer_var)) {
// check if the fragment has met before
FragmentInfo info = fragments[buffer_var];
CHECK_EQ(m->value, info.m);
CHECK_EQ(n->value, info.n);
CHECK_EQ(k->value, info.k);
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK_EQ(layout->value, info.layout);
}
} else {
// store metadata
FragmentInfo info;
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
info = FragmentInfo(m->value, n->value, k->value, layout->value);
} else if (scope == "wmma.accumulator") {
info = FragmentInfo(m->value, n->value, k->value, "");
}
fragments[buffer_var] = info;
}
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
// Get shape information from fill intrinsic
CHECK_EQ(op->args.size(), 6U);
const VarNode* buffer_var = op->args[0].as<VarNode>();
CHECK(buffer_var);
// Get shape
const IntImmNode* m = op->args[1].as<IntImmNode>();
const IntImmNode* n = op->args[2].as<IntImmNode>();
const IntImmNode* k = op->args[3].as<IntImmNode>();
CHECK(m);
CHECK(n);
CHECK(k);
std::string scope = scopes[buffer_var];
// Only wmma.accumulator can use tvm_fill_fragment
CHECK_EQ(scope, "wmma.accumulator");
if (fragments.count(buffer_var)) {
FragmentInfo info = fragments[buffer_var];
CHECK_EQ(m->value, info.m);
CHECK_EQ(n->value, info.n);
CHECK_EQ(k->value, info.k);
} else {
FragmentInfo info(m->value, n->value, k->value, "");
fragments[buffer_var] = info;
}
}
}
// Get memory scope
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
const VarNode* buffer = op->node.as<VarNode>();
CHECK(buffer);
scopes[buffer] = op->value.as<StringImmNode>()->value;
}
StmtExprVisitor::VisitStmt_(op);
}
// Memory scope for allocations
std::unordered_map<const VarNode*, std::string> scopes;
// Fragment metadata for all fragments
std::unordered_map<const VarNode*, FragmentInfo> fragments;
};
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync.
//
// For every tvm_mma_sync / tvm_bmma_sync call, verifies that the D, A, B and C
// fragment operands were all recorded by the FragmentGetter with the same
// (m, n, k) shape. A missing record or a mismatch fails a CHECK.
class FragmentChecker : public StmtExprVisitor {
 public:
  explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {}

  void VisitExpr_(const CallNode* op) final {
    StmtExprVisitor::VisitExpr_(op);
    // Check shape when calling tvm_mma_sync / tvm_bmma_sync
    if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) {
      CHECK_EQ(op->args.size(), 8U);
      // Arguments 0, 2, 4 and 6 are the D, A, B and C fragment buffers.
      const VarNode* buffer_var_d = op->args[0].as<VarNode>();
      const VarNode* buffer_var_a = op->args[2].as<VarNode>();
      const VarNode* buffer_var_b = op->args[4].as<VarNode>();
      const VarNode* buffer_var_c = op->args[6].as<VarNode>();
      CHECK(buffer_var_d);
      CHECK(buffer_var_a);
      CHECK(buffer_var_b);
      CHECK(buffer_var_c);
      // Check all fragment A, B, C and D have the same shape
      CHECK(CheckShape(buffer_var_d, buffer_var_a)) << "fragment A shape does not match fragment D";
      CHECK(CheckShape(buffer_var_d, buffer_var_b)) << "fragment B shape does not match fragment D";
      CHECK(CheckShape(buffer_var_d, buffer_var_c)) << "fragment C shape does not match fragment D";
    }
  }

 private:
  // Returns true when both fragments have the same (m, n, k) shape; both must
  // already have been recorded by the FragmentGetter.
  bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) const {
    const auto& fragments = fragment_getter.fragments;
    auto it1 = fragments.find(buffer1);
    auto it2 = fragments.find(buffer2);
    CHECK(it1 != fragments.end()) << "no fragment metadata recorded for an mma operand";
    CHECK(it2 != fragments.end()) << "no fragment metadata recorded for an mma operand";
    // Bind by reference: FragmentInfo holds a std::string, so copies are not free.
    const FragmentGetter::FragmentInfo& info1 = it1->second;
    const FragmentGetter::FragmentInfo& info2 = it2->second;
    return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
  }

  // Fragment information
  const FragmentGetter& fragment_getter;
};
// Store the metadata into attributes.
//
// Wraps every Allocate of a buffer known to the FragmentGetter with an
// attr::fragment_shape AttrStmt whose value is the string "m, n, k". When the
// fragment has a non-empty layout, an outer attr::fragment_layout AttrStmt is
// added as well.
class InferFragmenter : public StmtMutator {
 public:
  explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {}

  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    const VarNode* buffer = op->buffer_var.get();
    auto it = fragment_getter.fragments.find(buffer);
    if (it == fragment_getter.fragments.end()) {
      // Not a fragment: leave the allocation untouched.
      return stmt;
    }
    const FragmentGetter::FragmentInfo& info = it->second;
    // Add shape attribute to all fragments
    std::string shape =
        std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k);
    stmt = AttrStmt(op->buffer_var, attr::fragment_shape, StringImm(shape), stmt);
    if (!info.layout.empty()) {
      // Add layout attribute (outermost) to fragments that carry one
      stmt = AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), stmt);
    }
    return stmt;
  }

 private:
  // Fragment information
  const FragmentGetter& fragment_getter;
};
/*!
 * \brief Annotate WMMA fragment allocations in \p stmt with shape/layout attributes.
 *
 * Runs three stages in order: collect fragment metadata (FragmentGetter),
 * validate tvm_mma_sync / tvm_bmma_sync operand shapes (FragmentChecker), then
 * rewrite allocations with the collected metadata (InferFragmenter).
 */
Stmt InferFragment(Stmt stmt) {
  FragmentGetter collector;
  collector(stmt);

  FragmentChecker validator(collector);
  validator(stmt);

  InferFragmenter annotator(collector);
  return annotator(std::move(stmt));
}
namespace transform {
/*!
 * \brief Build the "tir.InferFragment" PrimFunc pass (opt level 0, no required passes).
 *
 * The pass replaces each function body with the result of the Stmt overload
 * of InferFragment defined above in namespace tir.
 */
Pass InferFragment() {
  auto rewrite_body = [](PrimFunc func, IRModule, PassContext) {
    auto* func_node = func.CopyOnWrite();
    func_node->body = InferFragment(std::move(func_node->body));
    return func;
  };
  return CreatePrimFuncPass(rewrite_body, 0, "tir.InferFragment", {});
}

// Expose the pass constructor through the global registry.
TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment);
}  // namespace transform
} // namespace tir
} // namespace tvm