blob: b057cb1509d72a8e32a112eb9bc2b5270107e8cd [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_trn.cc
*/
#include "codegen_trn.h"
#include <tvm/runtime/logging.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/transform.h>
#include <algorithm>
#include <cmath>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include "../../../runtime/thread_storage_scope.h"
#include "../../../target/build_common.h"
namespace tvm {
namespace codegen {
namespace {
/*!
 * \brief Render a buffer shape as a bracketed, comma-separated list.
 * \param shape The shape expressions; each one is written with its own operator<<.
 * \return Text of the form "[d0, d1, ...]", or "[]" when the shape is empty.
 */
std::string PrintShapeAsList(const ffi::Array<PrimExpr>& shape) {
  std::ostringstream out;
  out << "[";
  // The separator is empty before the first element and ", " afterwards.
  const char* separator = "";
  for (const auto& extent : shape) {
    out << separator << extent;
    separator = ", ";
  }
  out << "]";
  return out.str();
}
}  // namespace
/*!
 * \brief Reset per-function state before a PrimFunc is emitted.
 * \param f The function about to be generated.
 */
void CodeGenTrainium::InitFuncState(const PrimFunc& f) {
  // Nothing Trainium-specific is reset here; the base class does all the work.
  CodeGenC::InitFuncState(f);
}
/*!
 * \brief Construct the code generator and queue the Python preamble.
 *
 * Writes the import lines and kernel decorators to decl_stream and fills
 * opcode_map_, which translates TIR opcode names to NKI function names.
 *
 * \param target The compilation target, stored in target_.
 */
CodeGenTrainium::CodeGenTrainium(Target target) : target_(target) {
  // Emitted verbatim, in this order, ahead of the generated function.
  static const char* const kPreambleLines[] = {
      "import neuronxcc.nki.language as nl\n",
      "from neuronxcc.nki import baremetal, benchmark, simulate_kernel, trace\n",
      "import numpy as np\n",
      "import neuronxcc.nki.isa as nisa\n",
      "import math\n",
      "import neuronxcc.nki as nki\n",
      "import neuronxcc.nki.typing as nt\n",
      "import neuronxcc.nki.compiler as ncc\n",
      "@nki.compiler.enable_stack_allocator\n",
      "@nki.compiler.skip_middle_end_transformations\n",
      "@baremetal(experimental_flags='enable-mutable-parameter', "
      "additional_compile_opt='--internal-skip-backend-allocation-opt-nki')\n",
  };
  for (const char* line : kPreambleLines) {
    decl_stream << line;
  }
  opcode_map_ = {
      {"sqrt", "nki.language.sqrt"},
      {"add", "nki.language.add"},
      {"sub", "nki.language.subtract"},
      {"mul", "nki.language.multiply"},
      {"max", "nki.language.maximum"},
      {"min", "nki.language.minimum"},
      {"exp", "nki.language.exp"},
  };
}
/*!
 * \brief Emit one PrimFunc as a Python function definition.
 *
 * The generated text has the shape
 *   def <global_symbol>(<inputs...>, <outputs>: nt.mutable_tensor, ...):
 *       <body>
 *       return <outputs...>
 * Parameters at index >= the kNumInputs attribute are treated as outputs: they
 * are annotated as mutable tensors and listed in the final return statement.
 *
 * \param gvar The GlobalVar of the function (not used by this implementation).
 * \param func The PrimFunc to emit. It must carry the global_symbol and
 *        num_inputs attributes, and every parameter must be a handle.
 */
void CodeGenTrainium::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
  // NOTE: Inter-function calls are rejected by this codegen, so every function
  // is emitted on its own without a separate declaration pass.
  // clear previous generated state.
  this->InitFuncState(func);
  buffer_idmap_.clear();
  data_buffer_idmap_.clear();
  data_decl_buffer_map_.clear();
  // skip the first underscore, so SSA variable starts from _1
  name_supply_->FreshName("v_");
  auto global_symbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
  TVM_FFI_ICHECK(global_symbol.has_value())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  // Function header.
  this->stream << "def " << static_cast<std::string>(global_symbol.value()) << "(";
  // Buffer arguments
  auto num_inputs = func->GetAttr<int64_t>(tvm::attr::kNumInputs);
  TVM_FFI_ICHECK(num_inputs.has_value());
  const size_t first_output = static_cast<size_t>(num_inputs.value());
  std::vector<std::string> output_vids;
  for (size_t i = 0; i < func->params.size(); ++i) {
    Var v = func->params[i];
    if (!v.dtype().is_handle()) {
      LOG(FATAL) << "Trainium codegen currently only support buffer arguments";
    }
    std::string vid = AllocVarID(v.get());
    if (i >= first_output) {
      // Outputs are written in place, so they are declared mutable and returned.
      this->stream << vid << ": nt.mutable_tensor, ";
      output_vids.push_back(vid);
    } else {
      this->stream << vid << ", ";
    }
  }
  // the function scope.
  stream << "):\n";
  int func_scope = this->BeginScope();
  this->PrintStmt(func->body);
  this->PrintIndent();
  stream << "return ";
  for (size_t i = 0; i < output_vids.size(); i++) {
    if (i != 0) {
      stream << ", ";
    }
    stream << output_vids[i];
  }
  this->EndScope(func_scope);
}
/*!
 * \brief Write the Python dtype expression corresponding to a TIR data type.
 *
 * Only scalar, non-handle, non-void types are accepted. Bool and 1-bit
 * integers map to np.bool, float16/float32 and 8/16/32/64-bit (u)ints map to
 * the matching numpy dtype, and bfloat16 maps to nl.bfloat16. Anything else is
 * a fatal error.
 *
 * \param t The data type to print.
 * \param os The stream receiving the dtype text.
 */
void CodeGenTrainium::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  const int lanes = t.lanes();
  TVM_FFI_ICHECK(lanes == 1) << "Trainium codegen does not support vector types";
  TVM_FFI_ICHECK(!t.is_handle()) << "Trainium codegen does not support handle type";
  TVM_FFI_ICHECK(!t.is_void()) << "Trainium codegen does not support void type";

  if (t == DataType::Bool()) {
    os << "np.bool";
    return;
  }

  if (t.is_float()) {
    if (t.bits() == 16) {
      os << "np.float16";
    } else if (t.bits() == 32) {
      os << "np.float32";
    } else {
      LOG(FATAL) << "Trainium codegen does not support float type with bits " << t.bits();
    }
    return;
  }

  if (t.is_uint() || t.is_int()) {
    if (t.bits() == 1) {
      os << "np.bool";
      return;
    }
    // The prefix is written first; unsigned types share the signed suffix with a leading 'u'.
    os << "np.";
    if (t.is_uint()) {
      os << 'u';
    }
    const char* suffix = nullptr;
    switch (t.bits()) {
      case 8:
        suffix = "int8";
        break;
      case 16:
        suffix = "int16";
        break;
      case 32:
        suffix = "int32";
        break;
      case 64:
        suffix = "int64";
        break;
      default:
        break;
    }
    if (suffix == nullptr) {
      LOG(FATAL) << "Trainium codegen does not support int type with bits " << t.bits();
    } else {
      os << suffix;
    }
    return;
  }

  if (t.is_bfloat16()) {
    os << "nl.bfloat16";
    return;
  }

  LOG(FATAL) << "Cannot convert type " << t << " to Trainium type";
}
/*!
 * \brief Map a TIR storage scope name to the NKI buffer expression.
 * \param scope One of "global", "trn.sbuf" or "trn.psum".
 * \return "nl.hbm", "nl.sbuf" or "nl.psum" respectively; any other scope is a fatal error.
 */
std::string CodeGenTrainium::GetStorageScopeStr(const std::string& scope) {  // NOLINT(*)
  if (scope == "global") return "nl.hbm";
  if (scope == "trn.sbuf") return "nl.sbuf";
  if (scope == "trn.psum") return "nl.psum";
  LOG(FATAL) << "Unknown storage scope `" << scope << "`";
  return "";
}
/*!
 * \brief Emit an `nl.ndarray(...)` allocation for a buffer.
 *
 * PSUM buffers must be 3-D and have their middle dimension wrapped in
 * nl.par_dim(); every other scope prints the shape as a plain list. When an
 * allocated address is available (from the statement annotations, else from
 * the buffer itself) an explicit ncc.{psum,sbuf}.mod_alloc(...) is emitted as
 * the `buffer=` argument; otherwise the storage scope name is used.
 *
 * \param op The allocation statement.
 */
void CodeGenTrainium::VisitStmt_(const AllocBufferNode* op) {
  TVM_FFI_ICHECK(op->buffer.defined());
  const std::string vid = AllocVarID(op->buffer->data.get());
  this->PrintIndent();
  auto scope = GetPtrStorageScope(op->buffer->data);
  std::ostringstream dtype_os;
  PrintType(op->buffer->dtype, dtype_os);
  const std::string dtype_str = dtype_os.str();
  const bool in_psum = (scope == "trn.psum");

  // Left-hand side, shape and dtype.
  stream << vid << " = nl.ndarray(shape=";
  if (in_psum) {
    stream << "[";
    TVM_FFI_ICHECK(op->buffer->shape.size() == 3);
    stream << PrintExpr(op->buffer->shape[0]) << ", nl.par_dim(" << PrintExpr(op->buffer->shape[1])
           << "), " << PrintExpr(op->buffer->shape[2]) << "]";
  } else {
    stream << PrintShapeAsList(op->buffer->shape);
  }
  stream << ", dtype=" << dtype_str << ", buffer=";

  // Resolve the optional pre-assigned address.
  Array<PrimExpr> addr;
  if (auto allocated_addr = op->annotations.Get(tirx::attr::buffer_allocated_addr)) {
    addr = Downcast<Array<PrimExpr>>(allocated_addr.value());
  } else {
    // AllocBuffer is a leaf stmt after rebase; in that path allocated_addr is carried by Buffer.
    addr = op->buffer->allocated_addr;
  }

  if (addr.empty()) {
    // No fixed address: name only the memory space.
    stream << GetStorageScopeStr(scope) << ")\n";
    return;
  }

  if (!in_psum) {
    TVM_FFI_ICHECK(addr.size() == 1);
    TVM_FFI_ICHECK(addr[0]->IsInstance<IntImmNode>())
        << "allocated_addr[0] must be a constant integer, got: " << addr[0];
    const int64_t base_addr = Downcast<IntImm>(addr[0])->value;
    stream << "ncc.sbuf.mod_alloc(base_addr=" << base_addr << "))\n";
    return;
  }

  // PSUM addresses are a (bank, offset) pair.
  TVM_FFI_ICHECK(addr.size() == 2);
  TVM_FFI_ICHECK(addr[0]->IsInstance<IntImmNode>())
      << "allocated_addr[0] must be a constant integer, got: " << addr[0];
  TVM_FFI_ICHECK(addr[1]->IsInstance<IntImmNode>())
      << "allocated_addr[1] must be a constant integer, got: " << addr[1];
  const int64_t base_bank = Downcast<IntImm>(addr[0])->value;
  const int64_t base_addr = Downcast<IntImm>(addr[1])->value;
  stream << "ncc.psum.mod_alloc(base_bank=" << base_bank << ", base_addr=" << base_addr;
  stream << ", num_bank_tiles=(" << op->buffer->shape[0] << ",)))\n";
}
/*!
 * \brief Print the body of an attribute statement.
 *
 * For the tensorized_nki_instruction attribute, the tensorization context is
 * reset and ctx_.tensorizing is switched on while the body is printed, then
 * switched off again. Other attributes are transparent.
 *
 * \param op The attribute statement.
 */
void CodeGenTrainium::VisitStmt_(const AttrStmtNode* op) {
  const bool is_tensorized_instr = (op->attr_key == tirx::attr::tensorized_nki_instruction);
  if (is_tensorized_instr) {
    // Fresh context for this instruction.
    ctx_.tensorizing = true;
    ctx_.mask = PrimExpr(nullptr);
    ctx_.loopvar2dim.clear();
    ctx_.is_matmul_input = false;
  }
  this->PrintStmt(op->body);
  if (is_tensorized_instr) {
    ctx_.tensorizing = false;
  }
}
void CodeGenTrainium::VisitStmt_(const ForNode* op) {
bool is_outermost_loop = is_outermost_loop_;
is_outermost_loop_ = false;
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
TVM_FFI_ICHECK(is_zero(op->min));
if (ctx_.tensorizing) {
stream << vid << " = nl.arange(" << extent << ")\n";
if (op->annotations.count("nki_dim")) {
ctx_.loopvar2dim[op->loop_var.get()] = Downcast<ffi::String>(op->annotations["nki_dim"]);
}
ctx_.tensorized_loop_vars.insert(op->loop_var.get());
TVM_FFI_ICHECK(ctx_.loopvar2dim.empty() ||
ctx_.loopvar2dim.size() == ctx_.tensorized_loop_vars.size())
<< "nki_dim attribute must be specified for all tensorized loop variables or none of them";
PrintStmt(op->body);
ctx_.tensorized_loop_vars.erase(op->loop_var.get());
} else {
if (is_outermost_loop) {
stream << "for " << vid << " in nl.sequential_range(" << extent
<< ", body_no_reorder=True):\n";
} else {
stream << "for " << vid << " in nl.sequential_range(" << extent << "):\n";
}
int for_scope = BeginScope();
PrintStmt(op->body);
EndScope(for_scope);
}
is_outermost_loop_ = is_outermost_loop;
}
std::string CodeGenTrainium::PrintIndices(const Array<PrimExpr>& indices) {
std::ostringstream os;
ctx_.buffer_index = 0;
ctx_.used_var_cnt = 0;
for (size_t i = 0; i < indices.size(); ++i) {
PreOrderVisit(indices[i], [&](const ffi::ObjectRef& node) {
if (const auto* v = node.as<VarNode>()) {
if (ctx_.tensorized_loop_vars.count(v)) {
ctx_.used_var_cnt++;
}
}
return true;
});
}
for (size_t i = 0; i < indices.size(); ++i) {
if (i != 0) {
os << ", ";
}
os << PrintExpr(indices[i]);
}
ctx_.buffer_index = -1;
return os.str();
}
/*!
 * \brief Buffer stores are rejected: reaching one is always a fatal error.
 * \param op The offending store statement (unused).
 */
void CodeGenTrainium::VisitStmt_(const BufferStoreNode* op) {
  LOG(FATAL) << "Trainium codegen does not support buffer store";
}
/*!
 * \brief Emit an expression statement on its own indented line.
 *
 * Constant-integer evaluates are skipped, and so are expressions that print
 * as an empty string.
 *
 * \param op The evaluate statement.
 */
void CodeGenTrainium::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) {
    return;
  }
  const std::string text = this->PrintExpr(op->value);
  if (text.empty()) {
    return;
  }
  this->PrintIndent();
  this->stream << text << "\n";
}
/*!
 * \brief Print a buffer access as `<name>[<indices>]`.
 *
 * The name is the one recorded in buffer_idmap_ for this buffer when present,
 * otherwise the variable id of the buffer's data var.
 *
 * \param op The buffer load.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const BufferLoadNode* op, std::ostream& os) {
  std::string name;
  auto registered = buffer_idmap_.find(op->buffer);
  if (registered == buffer_idmap_.end()) {
    name = GetVarID(op->buffer->data.get());
  } else {
    name = registered->second;
  }
  os << name << "[";
  os << PrintIndices(op->indices);
  os << "]";
}
/*!
 * \brief Spell a boolean as a Python literal.
 * \param b The value to print.
 * \return "True" or "False".
 */
std::string PrintBool(bool b) {
  if (b) {
    return "True";
  }
  return "False";
}
void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
TVM_FFI_ICHECK(!op->op.as<GlobalVarNode>())
<< "CodegenTrainium does not support inter-function calls, "
<< "but expression " << ffi::GetRef<Call>(op) << " calls PrimFunc " << op->op;
const auto* op_node = op->op.as<OpNode>();
auto is_op = [&](const Op& compat, const char* canonical_name) {
return op->op.same_as(compat) || (op_node != nullptr && op_node->name == canonical_name);
};
static const Op& nki_matmul_op = Op::Get("tirx.nki.matmul");
static const Op& nki_load_op = Op::Get("tirx.nki.load");
static const Op& nki_store_op = Op::Get("tirx.nki.store");
static const Op& nki_tensor_copy_op = Op::Get("tirx.nki.tensor_copy");
static const Op& nki_activation_op = Op::Get("tirx.nki.activation");
static const Op& nki_reciprocal_op = Op::Get("tirx.nki.reciprocal");
static const Op& nki_tensortensor_op = Op::Get("tirx.nki.tensortensor");
static const Op& nki_tensorscalar_op = Op::Get("tirx.nki.tensorscalar");
static const Op& nki_memset_op = Op::Get("tirx.nki.memset");
static const Op& nki_tensorreduce_op = Op::Get("tirx.nki.tensorreduce");
static const Op& nki_activation_reduce_op = Op::Get("tirx.nki.activation_reduce");
static const Op& nki_tensorscalar_reduce_op = Op::Get("tirx.nki.tensorscalar_reduce");
static const Op& nki_identity_op = Op::Get("tirx.nki.identity");
static const Op& nki_scalar_tensor_tensor_op = Op::Get("tirx.nki.scalar_tensor_tensor");
static const Op& nki_scalar_tensor_scalar_op = Op::Get("tirx.nki.scalar_tensor_scalar");
static const Op& nki_affine_select_op = Op::Get("tirx.nki.affine_select");
if (is_op(nki_matmul_op, "tirx.nki.matmul")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 4);
std::string accum = is_one(op->args[3]) ? " += " : " = ";
os << PrintExpr(op->args[0]) << accum;
ctx_.is_matmul_input = true;
os << "nisa.nc_matmul(" << PrintExpr(op->args[1]) << "," << PrintExpr(op->args[2]);
} else if (is_op(nki_load_op, "tirx.nki.load")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 2);
os << PrintExpr(op->args[0]) << " = nl.load(" << PrintExpr(op->args[1]);
} else if (is_op(nki_store_op, "tirx.nki.store")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 2);
os << "nl.store(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]);
} else if (is_op(nki_tensor_copy_op, "tirx.nki.tensor_copy")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 2);
os << PrintExpr(op->args[0]) << " = nisa.tensor_copy(" << PrintExpr(op->args[1]);
} else if (is_op(nki_activation_op, "tirx.nki.activation")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 5);
// nki_activation(result, data, opcode, bias, scale)
TVM_FFI_ICHECK(opcode_map_.count(op->args[2].as<StringImmNode>()->value));
std::string nki_op = opcode_map_[op->args[2].as<StringImmNode>()->value];
os << PrintExpr(op->args[0]) << " = nisa.activation(op=" << nki_op
<< ", data=" << PrintExpr(op->args[1]) << ",";
os << "bias=" << PrintExpr(op->args[3]) << ", scale=" << PrintExpr(op->args[4]);
} else if (is_op(nki_reciprocal_op, "tirx.nki.reciprocal")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 2);
os << PrintExpr(op->args[0]) << " = nisa.reciprocal(" << PrintExpr(op->args[1]);
} else if (is_op(nki_tensortensor_op, "tirx.nki.tensortensor")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 4);
// nki_tensortensor(result, data1, data2, opcode)
TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as<StringImmNode>()->value));
std::string nki_op = opcode_map_[op->args[3].as<StringImmNode>()->value];
os << PrintExpr(op->args[0]) << " = nisa.tensor_tensor(" << PrintExpr(op->args[1]) << ", ";
os << PrintExpr(op->args[2]) << ", op=" << nki_op;
} else if (is_op(nki_tensorscalar_op, "tirx.nki.tensorscalar")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 5);
// nki_tensorscalar(result, operand0, operand1, opcode, reverse)
TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as<StringImmNode>()->value));
std::string nki_op = opcode_map_[op->args[3].as<StringImmNode>()->value];
bool reverse = op->args[4].as<IntImmNode>()->value != 0;
os << PrintExpr(op->args[0]) << " = nisa.tensor_scalar(" << PrintExpr(op->args[1])
<< ", operand0=";
os << PrintExpr(op->args[2]) << ", op0=" << nki_op << ", reverse0=" << PrintBool(reverse);
} else if (is_op(nki_memset_op, "tirx.nki.memset")) {
TVM_FFI_ICHECK_GE(op->args.size(), 2);
// result, value
os << PrintExpr(op->args[0]) << " = " << PrintExpr(op->args[1]);
TVM_FFI_ICHECK(!ctx_.mask.defined()) << "memset cannot have mask";
return;
} else if (is_op(nki_tensorreduce_op, "tirx.nki.tensorreduce")) {
TVM_FFI_ICHECK(op->args.size() >= 5)
<< "nki_tensorreduce expects at least 5 arguments, but got " << op->args.size();
// nki_tensorreduce(result, data, opcode, negate, *axes)
TVM_FFI_ICHECK(opcode_map_.count(op->args[2].as<StringImmNode>()->value));
std::string nki_op = opcode_map_[op->args[2].as<StringImmNode>()->value];
bool negate = op->args[3].as<IntImmNode>()->value != 0;
Array<PrimExpr> axes(op->args.begin() + 4, op->args.end());
os << PrintExpr(op->args[0]) << " = nisa.tensor_reduce(data=" << PrintExpr(op->args[1])
<< ", op=" << nki_op << ", negate=" << PrintBool(negate) << ", axis=" << axes;
} else if (is_op(nki_activation_reduce_op, "tirx.nki.activation_reduce")) {
TVM_FFI_ICHECK(op->args.size() == 7)
<< "nki_activation_reduce expects 7 arguments, but got " << op->args.size();
// nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias, scale)
TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as<StringImmNode>()->value));
std::string nki_op = opcode_map_[op->args[3].as<StringImmNode>()->value];
TVM_FFI_ICHECK(opcode_map_.count(op->args[4].as<StringImmNode>()->value));
std::string reduce_nki_op = opcode_map_[op->args[4].as<StringImmNode>()->value];
os << PrintExpr(op->args[1]) << " = nisa.activation_reduce(data=" << PrintExpr(op->args[2])
<< ", op=" << nki_op;
os << ", reduce_op=" << reduce_nki_op << ", reduce_res=" << PrintExpr(op->args[0])
<< ", bias=" << PrintExpr(op->args[5]) << ", scale=" << PrintExpr(op->args[6]);
} else if (is_op(nki_tensorscalar_reduce_op, "tirx.nki.tensorscalar_reduce")) {
TVM_FFI_ICHECK(op->args.size() == 7)
<< "nki_tensorscalar_reduce expects 7 arguments, but got " << op->args.size();
// nki_tensorscalar_reduce(reduce_res, tensorscalar_res, operand0, operand1, opcode,
// reduce_opcode, reverse)
TVM_FFI_ICHECK(opcode_map_.count(op->args[4].as<StringImmNode>()->value));
std::string nki_op = opcode_map_[op->args[4].as<StringImmNode>()->value];
TVM_FFI_ICHECK(opcode_map_.count(op->args[5].as<StringImmNode>()->value));
std::string reduce_nki_op = opcode_map_[op->args[5].as<StringImmNode>()->value];
bool reverse = op->args[6].as<IntImmNode>()->value != 0;
os << PrintExpr(op->args[1]) << " = nisa.tensor_scalar_reduce(data=" << PrintExpr(op->args[2])
<< ", op0=" << nki_op << ", operand0=" << PrintExpr(op->args[3])
<< ", reduce_op=" << reduce_nki_op << ", reduce_res=" << PrintExpr(op->args[0])
<< ", reverse0=" << PrintBool(reverse);
} else if (is_op(nki_identity_op, "tirx.nki.identity")) {
// nki_identity(result, size)
TVM_FFI_ICHECK_EQ(op->args.size(), 2);
auto identity_np_name = name_supply_->FreshName("identity_np");
os << identity_np_name << " = nl.shared_constant(np.identity(" << PrintExpr(op->args[1])
<< ", dtype=np.int8), dtype=nl.bfloat16)" << std::endl;
for (int i = 0; i < indent_; ++i) {
os << ' ';
}
os << PrintExpr(op->args[0]) << " = nl.load(" << identity_np_name;
} else if (is_op(nki_scalar_tensor_tensor_op, "tirx.nki.scalar_tensor_tensor")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 8);
// nki_scalar_tensor_tensor(result, data, operand0, operand1, opcode0, opcode1, reverse0,
// reverse1)
TVM_FFI_ICHECK(opcode_map_.count(op->args[4].as<StringImmNode>()->value));
std::string nki_op0 = opcode_map_[op->args[4].as<StringImmNode>()->value];
TVM_FFI_ICHECK(opcode_map_.count(op->args[5].as<StringImmNode>()->value));
std::string nki_op1 = opcode_map_[op->args[5].as<StringImmNode>()->value];
bool reverse0 = op->args[6].as<IntImmNode>()->value != 0;
bool reverse1 = op->args[7].as<IntImmNode>()->value != 0;
os << PrintExpr(op->args[0]) << " = nisa.scalar_tensor_tensor(data=" << PrintExpr(op->args[1])
<< ", operand0=" << PrintExpr(op->args[2]) << ", op0=" << nki_op0
<< ", reverse0=" << PrintBool(reverse0) << ", operand1=" << PrintExpr(op->args[3])
<< ", op1=" << nki_op1 << ", reverse1=" << PrintBool(reverse1);
} else if (is_op(nki_scalar_tensor_scalar_op, "tirx.nki.scalar_tensor_scalar")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 8);
// nki_scalar_tensor_scalar(result, data, operand0, operand1, opcode0, opcode1, reverse0,
// reverse1)
TVM_FFI_ICHECK(opcode_map_.count(op->args[4].as<StringImmNode>()->value));
std::string nki_op0 = opcode_map_[op->args[4].as<StringImmNode>()->value];
TVM_FFI_ICHECK(opcode_map_.count(op->args[5].as<StringImmNode>()->value));
std::string nki_op1 = opcode_map_[op->args[5].as<StringImmNode>()->value];
bool reverse0 = op->args[6].as<IntImmNode>()->value != 0;
bool reverse1 = op->args[7].as<IntImmNode>()->value != 0;
os << PrintExpr(op->args[0]) << " = nisa.tensor_scalar(data=" << PrintExpr(op->args[1])
<< ", operand0=" << PrintExpr(op->args[2]) << ", op0=" << nki_op0
<< ", reverse0=" << PrintBool(reverse0) << ", operand1=" << PrintExpr(op->args[3])
<< ", op1=" << nki_op1 << ", reverse1=" << PrintBool(reverse1);
} else if (is_op(nki_affine_select_op, "tirx.nki.affine_select")) {
TVM_FFI_ICHECK_EQ(op->args.size(), 4);
// nki_affine_select(result, pred, true_value, false_value)
os << PrintExpr(op->args[0]) << " = nisa.affine_select(pred=" << PrintExpr(op->args[1])
<< ", on_true_tile=" << PrintExpr(op->args[2])
<< ", on_false_value=" << PrintExpr(op->args[3]);
} else {
LOG(FATAL) << "Trainium codegen does not support call to " << op->op;
}
if (ctx_.mask.defined()) {
PreOrderVisit(ctx_.mask, [&](const ffi::ObjectRef& node) {
if (const auto* v = node.as<VarNode>()) {
if (ctx_.tensorized_loop_vars.count(v)) {
TVM_FFI_ICHECK(ctx_.loopvar2dim.count(v))
<< "nki_dim must be specified for tensorized loop variables used in mask. However, "
"it is not specified for "
<< ffi::GetRef<Var>(v);
auto dim_str = ctx_.loopvar2dim[v];
TVM_FFI_ICHECK(dim_str == "P" || dim_str == "F")
<< "Only nki_dim = P or F is allowed for tensorized loop variables used in mask. "
"However, "
<< ffi::GetRef<Var>(v) << " has nki_dim = " << dim_str;
}
}
return true;
});
os << ", mask=" << PrintExpr(ctx_.mask);
}
os << ")";
}
/*!
 * \brief Print a floating-point constant as a Python expression.
 *
 * Infinities become `math.inf` / `-math.inf`, NaN is a fatal error, and
 * finite values are written in scientific notation. The text is also passed to
 * MarkConst.
 *
 * \param op The constant.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  std::ostringstream literal;
  if (std::isinf(op->value)) {
    literal << (op->value < 0 ? "-math.inf" : "math.inf");
  } else if (std::isnan(op->value)) {
    LOG(FATAL) << "Trainium codegen does not support NaN";
  } else {
    literal << std::scientific << op->value;
  }
  const std::string text = literal.str();
  MarkConst(text);
  os << text;
}
/*!
 * \brief Print a variable, adding a broadcast subscript for tensorized loop vars.
 *
 * Ordinary variables print as their id. A tensorized loop variable is followed
 * by `[..., :, ...]`, a subscript holding `:` at one position and `None`
 * elsewhere:
 *  - with an nki_dim tag the subscript has two entries; "P" selects position 0,
 *    "F"/"rhs_F" position 1, and "lhs_F" position 1 only once
 *    ctx_.is_matmul_input is set (else 0);
 *  - without a tag the position is ctx_.buffer_index and the length is
 *    ctx_.used_var_cnt; if ctx_.buffer_index is -1 no subscript is printed.
 * ctx_.buffer_index is incremented whenever a subscript is printed.
 *
 * \param op The variable.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
  os << GetVarID(op);
  if (ctx_.tensorized_loop_vars.count(op) == 0) {
    // Not a tensorized loop variable: the bare id is all that is needed.
    return;
  }

  int slice_pos = 0;
  int subscript_len = 0;
  if (ctx_.loopvar2dim.count(op) == 0) {
    // No nki_dim tag: fall back to the position within the enclosing index list.
    if (ctx_.buffer_index == -1) {
      // Not inside a buffer index list, so the dimension is unknown.
      return;
    }
    slice_pos = ctx_.buffer_index;
    subscript_len = ctx_.used_var_cnt;
  } else {
    auto dim_str = ctx_.loopvar2dim[op];
    if (dim_str == "P") {
      slice_pos = 0;
    } else if (dim_str == "F" || dim_str == "rhs_F") {
      slice_pos = 1;
    } else if (dim_str == "lhs_F") {
      slice_pos = ctx_.is_matmul_input ? 1 : 0;
    } else {
      LOG(FATAL) << "Invalid nki_dim: " << dim_str;
    }
    subscript_len = 2;
  }

  os << "[";
  for (int i = 0; i < subscript_len; ++i) {
    os << (i == slice_pos ? ":, " : "None, ");
  }
  os << "]";
  ctx_.buffer_index++;
}
/*!
 * \brief Print a cast by printing its operand; no cast syntax is emitted.
 *
 * The target dtype is recorded in ctx_.dst_dtype before the operand is visited.
 *
 * \param op The cast expression.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const CastNode* op, std::ostream& os) {
  ctx_.dst_dtype = op->dtype;
  CodeGenTrainium::VisitExpr(op->value, os);
}
/*!
 * \brief Print floor division as Python `(a // b)`.
 *
 * The result is parenthesised so that it keeps its grouping when nested inside
 * another operator of equal precedence. Without the parentheses,
 * `c * floordiv(a, b)` would print as `c * a // b`, which Python evaluates as
 * `(c * a) // b`.
 *
 * \param op The floor-division expression.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const FloorDivNode* op, std::ostream& os) {
  os << "(" << PrintExpr(op->a) << " // " << PrintExpr(op->b) << ")";
}
/*!
 * \brief Print floor modulo as Python `(a % b)`.
 *
 * The result is parenthesised so that it keeps its grouping when nested inside
 * another operator of equal precedence. Without the parentheses,
 * `floordiv(a, floormod(b, c))` would print as `a // b % c`, which Python
 * evaluates as `(a // b) % c`.
 *
 * \param op The floor-modulo expression.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const FloorModNode* op, std::ostream& os) {
  os << "(" << PrintExpr(op->a) << " % " << PrintExpr(op->b) << ")";
}
void CodeGenTrainium::VisitStmt_(const DeclBufferNode* op) {
if (op->buffer.scope() == "trn.psum" || op->buffer.scope() == "trn.sbuf") {
return;
}
const VarNode* data = op->buffer->data.get();
auto it = data_buffer_idmap_.find(data);
if (it != data_buffer_idmap_.end()) {
const Buffer& prev_buffer = data_decl_buffer_map_.at(data);
if (ffi::StructuralEqual()(prev_buffer->shape, op->buffer->shape) &&
prev_buffer->dtype == op->buffer->dtype) {
buffer_idmap_[op->buffer] = it->second;
return;
}
}
std::string data_vid = GetVarID(data);
std::string buffer_vid = name_supply_->FreshName(data_vid + "_buffer");
buffer_idmap_[op->buffer] = buffer_vid;
data_buffer_idmap_[data] = buffer_vid;
data_decl_buffer_map_[data] = op->buffer;
PrintIndent();
stream << buffer_vid << " = " << data_vid << ".reshape(" << PrintShapeAsList(op->buffer->shape)
<< ")\n";
}
/*!
 * \brief Generate NKI Python source for every PrimFunc in a module.
 *
 * Each function is emitted by its own CodeGenTrainium instance and prefixed
 * with a `# Function: <name>` comment; the pieces are concatenated into one
 * source string. The module format is "Trainiumlib" when the global function
 * `tvm_callback_Trainium_compile` is registered, otherwise "Trainium".
 *
 * \param mod The IRModule to build; it may only contain PrimFuncs that carry a
 *        global_symbol attribute.
 * \param target The compilation target, forwarded to the code generator.
 * \return A device source module holding the generated source.
 */
ffi::Module BuildTrainium(IRModule mod, Target target) {
  const bool output_ssa = false;
  std::ostringstream source_maker;
  static auto fTrainium_compile = ffi::Function::GetGlobal("tvm_callback_Trainium_compile");
  std::string fmt = fTrainium_compile.has_value() ? "Trainiumlib" : "Trainium";
  for (const auto& kv : mod->functions) {
    TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTrainium: Can only take PrimFunc";
    auto global_symbol = kv.second->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    TVM_FFI_ICHECK(global_symbol.has_value());
    std::string func_name = global_symbol.value();
    source_maker << "# Function: " << func_name << "\n";
    // A fresh generator per function, so each one gets its own preamble.
    CodeGenTrainium cg(target);
    cg.Init(output_ssa);
    auto f = Downcast<PrimFunc>(kv.second);
    cg.AddFunction(kv.first, f);
    source_maker << cg.Finish() << "\n";
  }
  return codegen::DeviceSourceModuleCreate(source_maker.str(), fmt, ExtractFuncInfo(mod), "nki");
}
/*!
 * \brief Emit a conditional.
 *
 * Inside a tensorized instruction no `if` is printed. The condition is stored
 * in ctx_.mask instead, and only the then-branch is visited; an else branch or
 * a second `if` is an error. Elsewhere a regular Python `if` / `else:` is
 * emitted.
 *
 * \param op The conditional statement.
 */
void CodeGenTrainium::VisitStmt_(const IfThenElseNode* op) {
  if (ctx_.tensorizing) {
    TVM_FFI_ICHECK(!op->else_case.defined()) << "Else not allowed in tensorized instruction";
    TVM_FFI_ICHECK(!ctx_.mask.defined()) << "Only one if stmt allowed in tensorized instruction";
    ctx_.mask = op->condition;
    VisitStmt(op->then_case);
    return;
  }

  const std::string condition = PrintExpr(op->condition);
  PrintIndent();
  stream << "if " << condition << " :\n";
  const int then_scope = BeginScope();
  PrintStmt(op->then_case);
  this->EndScope(then_scope);

  if (!op->else_case) {
    return;
  }
  PrintIndent();
  stream << "else:\n";
  const int else_scope = BeginScope();
  PrintStmt(op->else_case.value());
  this->EndScope(else_scope);
}
/*!
 * \brief Print logical AND as Python `(a & b)`.
 *
 * The result is parenthesised because Python's `&` binds tighter than `|`.
 * Without the parentheses, `And(Or(a, b), c)` combined with an unparenthesised
 * Or would print as `a | b & c`, which Python evaluates as `a | (b & c)`.
 *
 * \param op The AND expression.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const AndNode* op, std::ostream& os) {
  os << "(" << PrintExpr(op->a) << " & " << PrintExpr(op->b) << ")";
}
/*!
 * \brief Print logical OR as Python `(a | b)`.
 *
 * The result is parenthesised because Python's `|` binds looser than `&`.
 * Without the parentheses, an Or nested as an operand of an unparenthesised
 * `&` would print as `a | b & c`, which Python evaluates as `a | (b & c)`.
 *
 * \param op The OR expression.
 * \param os The output stream.
 */
void CodeGenTrainium::VisitExpr_(const OrNode* op, std::ostream& os) {
  os << "(" << PrintExpr(op->a) << " | " << PrintExpr(op->b) << ")";
}
/*!
 * \brief Register BuildTrainium under the global name "target.build.trn".
 *
 * A function-local flag makes repeated calls no-ops after the first.
 */
void RegisterTRNCodegen() {
  static bool registered = false;
  if (!registered) {
    registered = true;
    namespace refl = tvm::ffi::reflection;
    refl::GlobalDef().def("target.build.trn", BuildTrainium);
  }
}
} // namespace codegen
} // namespace tvm