blob: b69f76914495b0186f6694815dc8d8c0d7503b1c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c.cc
*/
#include "codegen_c.h"
#include <tvm/arith/analyzer.h>
#include <cctype>
#include <iomanip>
#include "../../arith/pattern_match.h"
#include "codegen_params.h"
namespace tvm {
namespace codegen {
using namespace tir;
void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; }
void CodeGenC::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_.clear();
handle_data_type_.clear();
CodeGenSourceBase::ClearFuncState();
}
/*!
 * \brief Reserve names that must never be handed out as generated identifiers.
 *
 * The names are reserved in the order listed below.
 */
void CodeGenC::ReserveKeywordsAsUnique() {
  const char* const reserved_names[] = {
      // skip the first underscore, so SSA variable starts from _1
      "_",
      // C keywords
      "extern",   "void",   "int",      "float",  "double",  "char",   "unsigned",
      "short",    "long",   "if",       "else",   "switch",  "case",   "default",
      "for",      "do",     "while",    "goto",   "register", "continue", "break",
      "typedef",  "struct", "enum",     "union",  "return",
  };
  for (const char* reserved : reserved_names) {
    name_supply_->ReserveName(reserved);
  }
}
void CodeGenC::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix();
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
if (no_alias) {
PrintRestrict(v, stream);
}
} else {
PrintType(GetType(v), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->PrintFinalReturn();
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
/*! \brief Write the tokens that precede the function name; here just `void`. */
void CodeGenC::PrintFuncPrefix() {
  stream << "void";
}
/*! \brief Hook for extra function attributes; this implementation emits nothing. */
void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {
}
/*! \brief Hook for a trailing return statement; this implementation emits nothing. */
void CodeGenC::PrintFinalReturn() {
}
/*!
 * \brief Assemble the generated source.
 * \return The declaration stream contents followed by the main stream contents.
 */
std::string CodeGenC::Finish() {
  std::string source = decl_stream.str();
  source += stream.str();
  return source;
}
/*!
 * \brief Print an expression to `os`.
 *
 * Outside SSA mode the expression is visited straight into `os`. In SSA mode
 * it is rendered to a temporary first and `os` receives the identifier that
 * SSAGetID returns for that text.
 * \param n The expression to print.
 * \param os The destination stream.
 */
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
  if (!print_ssa_form_) {
    VisitExpr(n, os);
    return;
  }
  std::ostringstream rendered;
  VisitExpr(n, rendered);
  os << SSAGetID(rendered.str(), n.dtype());
}
// Defined at the end of this file.
static bool CheckOutermostBracketMatch(const std::string& s);
/*!
 * \brief Emit `<type> <target> = <src>;` into the main stream.
 *
 * When `src` is wrapped in one matching pair of outermost parentheses, that
 * pair is dropped from the right-hand side.
 * \param target Name of the variable being defined.
 * \param src Text of the value expression.
 * \param t Type of the variable.
 */
void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) {
  PrintType(t, stream);
  stream << ' ' << target << " = ";
  const bool wrapped = CheckOutermostBracketMatch(src);
  stream << (wrapped ? src.substr(1, src.length() - 2) : src);
  stream << ";\n";
}
// Build the text of a C expression that refers to `buffer` at `index`,
// accessed as a value of type `t`.
//
// \param t      Type of the value being accessed.
// \param buffer Buffer whose data variable is referenced.
// \param index  Index expression; it is rendered with PrintExpr.
// \return The reference expression as a string.
std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
const VarNode* buffer_var = buffer->data.get();
std::ostringstream os;
std::string vid = GetVarID(buffer_var);
// Storage scope recorded for this variable, or empty when none was recorded.
std::string scope;
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
bool is_vol = IsVolatile(buffer_var);
// Produces the cast prefix "(<volatile ><scope><type>*)" for a given pointee type.
auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
std::ostringstream ptr_os;
ptr_os << "(";
if (is_vol) {
ptr_os << "volatile ";
}
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, ptr_os);
}
PrintType(pointed_to, ptr_os);
ptr_os << "*)";
return ptr_os.str();
};
DataType buffer_element_dtype = buffer->dtype;
// Use the bare variable name unless the registered handle type differs from
// the buffer element type or the buffer is volatile; then wrap it in a cast.
std::string buffer_str = vid;
if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
std::stringstream temp;
temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
buffer_str = temp.str();
}
std::string index_str = PrintExpr(index);
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
// However, the backing type for scalar int4 and scalar bool is
// int32. Therefore, we need to divide by the ratio of their
// sizes in that case.
int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();
// Result shape: *((<cast to t*>vid) + index / div_factor)
os << "*("
<< "(" << ptr_cast(t) << vid << ")"
<< " + " << index_str << " / " << div_factor << ")";
} else if (t == buffer_element_dtype) {
// Access type equals the element type: plain array indexing.
os << buffer_str << "[" << index_str << "]";
} else {
// Access type differs from the element type: offset the pointer, cast it
// to `t*`, then dereference.
os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
}
return os.str();
}
/*!
 * \brief Build the C expression that accesses a struct field or value slot.
 *
 * Kinds below builtin::kArrKindBound_ treat `buffer` as a `DLTensor*`:
 * kArrAddr yields the address `buffer + index`, every other kind yields the
 * named field of `buffer[index]`. The remaining kinds (which must be below
 * builtin::kTVMValueKindBound_) treat `buffer` as a `TVMValue*` and pick the
 * union member according to `t`.
 *
 * \param t Type of the accessed value; only consulted for the TVMValue kinds.
 * \param buffer Expression evaluating to the struct array handle.
 * \param index Expression selecting the array element.
 * \param kind Field kind code from the builtin namespace.
 * \return The access expression as text.
 */
std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index,
                                   int kind) {
  if (kind < builtin::kArrKindBound_) {
    std::ostringstream os;
    os << "(((DLTensor*)";
    this->PrintExpr(buffer, os);
    os << ")";
    if (kind == builtin::kArrAddr) {
      os << " + ";
      this->PrintExpr(index, os);
      os << ")";
      return os.str();
    }
    os << '[';
    this->PrintExpr(index, os);
    os << "].";
    // other case: get fields.
    switch (kind) {
      case builtin::kArrData:
        os << "data";
        break;
      case builtin::kArrShape:
        os << "shape";
        break;
      case builtin::kArrStrides:
        os << "strides";
        break;
      case builtin::kArrNDim:
        os << "ndim";
        break;
      case builtin::kArrTypeCode:
        os << "dtype.code";
        break;
      case builtin::kArrTypeBits:
        os << "dtype.bits";
        break;
      case builtin::kArrByteOffset:
        os << "byte_offset";
        break;
      case builtin::kArrTypeLanes:
        os << "dtype.lanes";
        break;
      case builtin::kArrDeviceId:
        os << "device.device_id";
        break;
      case builtin::kArrDeviceType:
        os << "device.device_type";
        break;
      default:
        LOG(FATAL) << "unknown field code";
    }
    os << ')';
    return os.str();
  } else {
    ICHECK_LT(kind, builtin::kTVMValueKindBound_);
    std::ostringstream os;
    os << "(((TVMValue*)";
    this->PrintExpr(buffer, os);
    os << ")[";
    // Render the index through the code generator, exactly like the DLTensor
    // branch above, so the output is C source rather than the IR printer's
    // textual form of the expression.
    this->PrintExpr(index, os);
    os << "].";
    if (t.is_handle()) {
      os << "v_handle";
    } else if (t.is_float()) {
      os << "v_float64";
    } else if (t.is_int()) {
      os << "v_int64";
    } else {
      LOG(FATAL) << "Do not know how to handle type" << t;
    }
    os << ")";
    return os.str();
  }
}
/*!
 * \brief Check whether `buf_var` has been registered with element type `t`.
 * \return true only if an entry exists and its type equals `t`.
 */
bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
  auto entry = handle_data_type_.find(buf_var);
  return entry != handle_data_type_.end() && entry->second == t;
}
/*!
 * \brief Record the element type a handle variable points to.
 *
 * Registering the same variable again is allowed only with the same type.
 * \param buf_var The handle variable.
 * \param t The element type.
 */
void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) {
  auto entry = handle_data_type_.find(buf_var);
  if (entry != handle_data_type_.end()) {
    ICHECK(entry->second == t) << "conflicting buf var type";
    return;
  }
  handle_data_type_[buf_var] = t;
}
/*!
 * \brief Print an access to lane `i` of vector `vec` as `<vec>.s<i in hex>`.
 * \param vec Text of the vector expression.
 * \param t Vector type (unused here).
 * \param i Lane number.
 * \param os Destination stream; it is switched back to decimal afterwards.
 */
void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i,
                                std::ostream& os) {  // NOLINT(*)
  os << vec << ".s";
  os << std::hex << i;
  os << std::dec;
}
void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec;
}
/*!
 * \brief Build the expression for a vector load starting at `base`.
 *
 * This implementation simply forwards to GetBufferRef.
 */
std::string CodeGenC::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) {
  std::string ref = GetBufferRef(t, buffer, base);
  return ref;
}
void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
const std::string& value) {
std::string ref = GetBufferRef(t, buffer, base);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
}
std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")" << value << ")";
return os.str();
}
/*! \brief Thread index binding is not available in this generator; always fails. */
void CodeGenC::BindThreadIndex(const IterVar& iv) {
  LOG(FATAL) << "not implemented";
}
/*! \brief Storage synchronization hook; this implementation emits nothing. */
void CodeGenC::PrintStorageSync(const CallNode* op) {  // NOLINT(*)
}
/*! \brief Storage scope hook; only the "global" scope is accepted and nothing is printed. */
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
  ICHECK_EQ(scope, "global");
}
/*!
 * \brief Print a signed integer immediate.
 *
 * An int32 value is printed as a bare literal and reported to the generator
 * through MarkConst; any other type is printed as `(<type>)<value>`.
 */
inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
  const bool is_int32 = (op->dtype == DataType::Int(32));
  if (!is_int32) {
    os << "(";
    p->PrintType(op->dtype, os);
    os << ")" << op->value;
    return;
  }
  std::ostringstream literal;
  literal << op->value;
  p->MarkConst(literal.str());
  os << literal.str();
}
/*!
 * \brief Print an unsigned integer constant.
 *
 * A uint32 value is printed as `<val>U` and reported to the generator through
 * MarkConst; any other type is printed as `(<type>)<val>`.
 */
inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os,
                           CodeGenC* p) {  // NOLINT(*)
  const bool is_uint32 = (dtype == DataType::UInt(32));
  if (!is_uint32) {
    os << "(";
    p->PrintType(dtype, os);
    os << ")" << val;
    return;
  }
  std::ostringstream literal;
  literal << val << "U";
  p->MarkConst(literal.str());
  os << literal.str();
}
/*!
 * \brief Print a floating point immediate in scientific notation.
 *
 * 64- and 32-bit values become bare literals (with an `f` suffix for 32 bits)
 * and are reported through MarkConst. 16-bit values are printed as
 * `(<type>)<value>f`. Any other width is a fatal error.
 */
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
  const int bits = op->dtype.bits();
  if (bits == 64 || bits == 32) {
    std::ostringstream literal;
    literal << std::scientific << op->value;
    if (bits == 32) literal << 'f';
    p->MarkConst(literal.str());
    os << literal.str();
  } else if (bits == 16) {
    os << '(';
    p->PrintType(op->dtype, os);
    os << ')' << std::scientific << op->value << 'f';
  } else {
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}
/*! \brief Integer immediates are printed by the IntImmNode overload of PrintConst. */
void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
  PrintConst(op, os, this);
}
/*! \brief Float immediates are printed by the FloatImmNode overload of PrintConst. */
void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  PrintConst(op, os, this);
}
/*! \brief String immediates are written between double quotes, contents unmodified. */
void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) {  // NOLINT(*)
  os << "\"";
  os << op->value;
  os << "\"";
}
/*!
 * \brief Print a binary expression node.
 *
 * Scalar operands are printed as `opstr(a, b)` when `opstr` starts with a
 * letter and as `(a opstr b)` otherwise. Vector operands are delegated to
 * CodeGenC::PrintVecBinaryOp.
 * \param op Node providing `dtype`, `a` and `b`.
 * \param opstr Operator symbol or function name.
 * \param os Destination stream.
 * \param p The code generator used to print the operands.
 */
template <typename T>
inline void PrintBinaryExpr(const T* op, const char* opstr,
                            std::ostream& os,  // NOLINT(*)
                            CodeGenC* p) {
  if (op->dtype.lanes() != 1) {
    p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
    return;
  }
  // Opening part and separator differ between call and infix notation;
  // the right operand and the closing parenthesis are shared.
  if (isalpha(opstr[0])) {
    os << opstr << '(';
    p->PrintExpr(op->a, os);
    os << ", ";
  } else {
    os << '(';
    p->PrintExpr(op->a, os);
    os << ' ' << opstr << ' ';
  }
  p->PrintExpr(op->b, os);
  os << ')';
}
/*!
 * \brief Print a two-argument intrinsic call as an infix operation.
 *
 * Scalars are printed as `(<arg0><opstr><arg1>)`; `opstr` is inserted as-is,
 * so callers include any surrounding spaces. Vectors are delegated to
 * CodeGenC::PrintVecBinaryOp.
 * \param op The call; it must have exactly two arguments.
 * \param opstr Operator text.
 * \param os Destination stream.
 * \param p The code generator used to print the operands.
 */
inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr,
                                 std::ostream& os,  // NOLINT(*)
                                 CodeGenC* p) {
  // Both branches read args[0] and args[1], so validate the arity up front
  // instead of only on the scalar path.
  ICHECK_EQ(op->args.size(), 2U);
  if (op->dtype.lanes() == 1) {
    os << '(';
    p->PrintExpr(op->args[0], os);
    os << opstr;
    p->PrintExpr(op->args[1], os);
    os << ')';
  } else {
    p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
  }
}
void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*)
std::stringstream value;
this->PrintExpr(op->value, value);
os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
}
/*! \brief Print a variable as the identifier GetVarID returns for it. */
void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
  const std::string id = GetVarID(op);
  os << id;
}
// Arithmetic nodes: each one maps to the C operator of the same meaning.
/*! \brief Addition, printed with `+`. */
void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "+", os, this);
}
/*! \brief Subtraction, printed with `-`. */
void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "-", os, this);
}
/*! \brief Multiplication, printed with `*`. */
void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "*", os, this);
}
/*! \brief Division, printed with `/`. */
void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "/", os, this);
}
/*!
 * \brief Print a modulo operation.
 *
 * Integer types use `%`. Floating point types use `fmodf` (32 bits) or `fmod`
 * (64 bits); any other dtype fails a check.
 */
void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
  const bool integral = op->dtype.is_int() || op->dtype.is_uint();
  if (integral) {
    PrintBinaryExpr(op, "%", os, this);
    return;
  }
  ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got "
                               << op->dtype;
  switch (op->dtype.bits()) {
    case 32:
      PrintBinaryExpr(op, "fmodf", os, this);
      break;
    case 64:
      PrintBinaryExpr(op, "fmod", os, this);
      break;
    default:
      ICHECK(false)
          << "Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
          << op->dtype.bits() << " bits.";
  }
}
// min/max are printed in call notation because their names start with a letter.
/*! \brief Minimum, printed as `min(a, b)` for scalars. */
void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "min", os, this);
}
/*! \brief Maximum, printed as `max(a, b)` for scalars. */
void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "max", os, this);
}
// Comparisons.
/*! \brief Equality, printed with `==`. */
void CodeGenC::VisitExpr_(const EQNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "==", os, this);
}
/*! \brief Inequality, printed with `!=`. */
void CodeGenC::VisitExpr_(const NENode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "!=", os, this);
}
/*! \brief Less-than, printed with `<`. */
void CodeGenC::VisitExpr_(const LTNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "<", os, this);
}
/*! \brief Less-or-equal, printed with `<=`. */
void CodeGenC::VisitExpr_(const LENode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "<=", os, this);
}
/*! \brief Greater-than, printed with `>`. */
void CodeGenC::VisitExpr_(const GTNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, ">", os, this);
}
/*! \brief Greater-or-equal, printed with `>=`. */
void CodeGenC::VisitExpr_(const GENode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, ">=", os, this);
}
// Logical connectives.
/*! \brief Logical and, printed with `&&`. */
void CodeGenC::VisitExpr_(const AndNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "&&", os, this);
}
/*! \brief Logical or, printed with `||`. */
void CodeGenC::VisitExpr_(const OrNode* op, std::ostream& os) {  // NOLINT(*)
  PrintBinaryExpr(op, "||", os, this);
}
void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*)
os << '!';
PrintExpr(op->a, os);
}
/*!
 * \brief Print a call to an external symbol as `<symbol>(<arg>, <arg>, ...)`.
 * \param ret_type Return type of the call (unused here).
 * \param global_symbol Name of the function to call.
 * \param args Call arguments.
 * \param skip_first_arg When true, printing starts at args[1].
 * \param os Destination stream.
 */
void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
                               bool skip_first_arg, std::ostream& os) {  // NOLINT(*)
  os << global_symbol << "(";
  const size_t first = static_cast<size_t>(skip_first_arg);
  for (size_t pos = first; pos < args.size(); ++pos) {
    // Separator goes between consecutive arguments.
    if (pos != first) {
      os << ", ";
    }
    this->PrintExpr(args[pos], os);
  }
  os << ")";
}
void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
if (op->op.same_as(builtin::tvm_check_return())) {
const CallNode* call = op->args[2].as<CallNode>();
os << "if (";
VisitExpr_(call, os);
os << " != ";
PrintExpr(op->args[0], os);
os << " ) return ";
PrintExpr(op->args[1], os);
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
op->args, false, os);
} else if (op->op.same_as(builtin::bitwise_and())) {
PrintBinaryIntrinsic(op, " & ", os, this);
} else if (op->op.same_as(builtin::large_uint_imm())) {
ICHECK_EQ(op->args.size(), 2U);
uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
uint64_t val = (high << 32U) | low;
PrintUIntConst(op->dtype, val, os, this);
} else if (op->op.same_as(builtin::bitwise_xor())) {
PrintBinaryIntrinsic(op, " ^ ", os, this);
} else if (op->op.same_as(builtin::bitwise_or())) {
PrintBinaryIntrinsic(op, " | ", os, this);
} else if (op->op.same_as(builtin::bitwise_not())) {
ICHECK_EQ(op->args.size(), 1U);
os << "(~";
this->PrintExpr(op->args[0], os);
os << ')';
} else if (op->op.same_as(builtin::shift_left())) {
PrintBinaryIntrinsic(op, " << ", os, this);
} else if (op->op.same_as(builtin::shift_right())) {
PrintBinaryIntrinsic(op, " >> ", os, this);
} else if (op->op.same_as(builtin::if_then_else())) {
os << "(";
PrintExpr(op->args[0], os);
os << " ? ";
PrintExpr(op->args[1], os);
os << " : ";
PrintExpr(op->args[2], os);
os << ")";
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations.";
os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))";
} else if (op->op.same_as(builtin::tvm_struct_get())) {
ICHECK_EQ(op->args.size(), 3U);
os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
} else if (op->op.same_as(builtin::isnullptr())) {
ICHECK_EQ(op->args.size(), 1U);
os << "(";
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else if (op->op.same_as(builtin::reinterpret())) {
int ssa_scope = BeginScope();
std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "(*(";
this->PrintType(op->dtype, os);
os << " *)(&(" << rhs << ")))";
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::isnan())) {
os << "(";
this->PrintExpr(op->args[0], os);
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else if (op->op.same_as(builtin::lookup_param())) {
ICHECK_EQ(op->args.size(), 1);
const StringImmNode* str = op->args[0].as<StringImmNode>();
ICHECK(str != nullptr);
os << "__tvm_param__" << str->value;
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
} else {
ICHECK(op->op.as<GlobalVarNode>());
LOG(FATAL) << "Do not yet support cross function call";
}
}
/*!
 * \brief Print a binary operation on vector operands.
 *
 * An operator that starts with a letter is printed in call notation,
 * `op(lhs, rhs)`; anything else is printed infix as `(lhs op rhs)`.
 * \param op Operator symbol or function name.
 * \param t Result type (unused here).
 * \param lhs Left operand.
 * \param rhs Right operand.
 * \param os Destination stream.
 */
void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                std::ostream& os) {  // NOLINT(*)
  // isalpha() is only defined for values representable as unsigned char (or
  // EOF), so convert explicitly instead of passing a possibly negative char.
  const bool call_notation = !op.empty() && isalpha(static_cast<unsigned char>(op[0]));
  if (call_notation) {
    os << op << "(";
    this->PrintExpr(lhs, os);
    os << ", ";
    this->PrintExpr(rhs, os);
    os << ")";
  } else {
    os << "(";
    this->PrintExpr(lhs, os);
    os << ' ' << op << ' ';
    this->PrintExpr(rhs, os);
    os << ")";
  }
}
/*!
 * \brief Emit a constant allocation as a static const array in the
 *        declaration stream, then print the statement body.
 *
 * The array is named after the buffer variable's name hint, and that name is
 * recorded as the variable's identifier.
 */
void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
  const std::string symbol_name = op->buffer_var->name_hint;
  const auto& data = op->data.value();

  // Total element count is the product of all shape extents.
  int64_t num_elements = 1;
  for (int64_t extent : data.Shape()) {
    num_elements *= extent;
  }

  // Open an extern "C" region (C++ only) and start the declaration.
  decl_stream << "\n";
  decl_stream << "#ifdef __cplusplus\n";
  decl_stream << "extern \"C\" {\n";
  decl_stream << "#endif\n";
  decl_stream << "static const ";
  PrintType(data.DataType(), decl_stream);

  // Allocate the global static variable
  decl_stream << " __attribute__((section(\".rodata.tvm\"), ";
  decl_stream << "aligned(" << constants_byte_alignment_->value << "))) ";
  decl_stream << symbol_name << "[" << num_elements << "] = {\n";
  NDArrayDataToC(data, 4, decl_stream);

  // Close the initializer and the extern "C" region.
  decl_stream << "};\n";
  decl_stream << "#ifdef __cplusplus\n";
  decl_stream << "} // extern \"C\"\n";
  decl_stream << "#endif\n";

  var_idmap_[op->buffer_var.get()] = symbol_name;
  this->PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); }
/*! \brief LoadNode is deprecated; encountering one is a fatal error. */
void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Unexpected deprecated LoadNode. Use BufferLoadNode instead.";
}
// Print a load from a flat (single-index) buffer.
//
// Three cases are distinguished:
//  - the loaded value has as many lanes as the buffer element type: one
//    buffer reference is produced;
//  - the index is a stride-1 ramp whose base is provably a multiple of the
//    lane count: one vector load is produced via GetVecLoad;
//  - otherwise: every lane is loaded separately and the pieces are combined
//    with PrintVecElemLoadExpr.
void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*)
ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported.";
DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;
int lanes = op->dtype.lanes();
// declare type.
if (value_dtype.lanes() == element_dtype.lanes()) {
// Lane counts agree: the reference is passed on to HandleVolatileLoads together with `os`.
std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
HandleVolatileLoads(ref, op, os);
} else {
bool can_vector_load = false;
arith::PVar<PrimExpr> base;
// Match index against ramp(base, 1, lanes).
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
const RampNode* ramp = index.as<RampNode>();
ICHECK(ramp);
arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
// The condition: {k * coeff + base} divisible by the alignment for any k
if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) {
can_vector_load = true;
}
}
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
} else {
// Lane-by-lane fallback. The index vector is bound to an SSA name first so
// that each lane of it can be addressed with PrintVecElemLoad.
std::ostringstream svalue_expr;
std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
std::string vid = GetVarID(buffer_var.get());
DataType elem_type = op->dtype.element_of();
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
// Cast the buffer pointer to the scalar element type unless it is already
// registered with that type.
if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
value_temp << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
}
}
PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')';
} else {
value_temp << vid;
}
// value_temp becomes "<buffer>[<lane i of the index vector>]".
value_temp << '[';
PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
value_temp << ']';
PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
}
os << svalue_expr.str();
}
}
}
/*! \brief StoreNode is deprecated; encountering one is a fatal error. */
void CodeGenC::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected deprecated StoreNode. Use BufferStoreNode instead.";
}
// Print a store to a flat (single-index) buffer.
//
// Three cases are distinguished:
//  - the stored value has as many lanes as the buffer element type: a single
//    assignment to the buffer reference is emitted;
//  - the index is a stride-1 ramp: the store is delegated to PrintVecStore;
//  - otherwise: one scalar assignment per lane is emitted.
void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index_expr = op->indices[0];
Var buffer_var = op->buffer->data;
if (value_dtype.lanes() == element_dtype.lanes()) {
// The value is rendered before the reference, and both before the indent is written.
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
} else {
arith::PVar<PrimExpr> base;
// Match index_expr against ramp(base, 1, lanes).
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index_expr)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();
// store elements separately
// Index and value vectors are bound to SSA names so that their lanes can
// be addressed with PrintVecElemLoad.
std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype());
std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
std::string vid = GetVarID(buffer_var.get());
for (int i = 0; i < value_dtype.lanes(); ++i) {
this->PrintIndent();
DataType elem_type = value_dtype.element_of();
// Cast the buffer pointer to the scalar element type unless it is already
// registered with that type.
if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
stream << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
}
PrintType(elem_type, stream);
stream << "*)" << vid << ')';
} else {
stream << vid;
}
// Emits "<buffer>[<index lane i>] = <value lane i>;".
stream << '[';
PrintVecElemLoad(index, index_expr.dtype(), i, stream);
stream << "] = ";
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
EndScope(vec_scope);
}
}
}
/*!
 * \brief Print a let expression by substituting the bound value.
 *
 * The variable's identifier is set to the printed value, after which the body
 * is printed. Binding the same variable again is accepted only when the value
 * is structurally the same as the recorded one.
 */
void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
  auto bound = let_binding_.find(op->var);
  if (bound == let_binding_.end()) {
    let_binding_[op->var] = op;
  } else {
    ICHECK(deep_equal_(bound->second->value, op->value))
        << "Let cannot bind the same var to two different values";
  }
  std::string value = PrintExpr(op->value);
  var_idmap_[op->var.get()] = value;
  os << PrintExpr(op->body);
}
/*!
 * \brief Print a ramp as an `int<lanes>` vector literal whose i-th element is
 *        `(base)+(stride*i)`.
 *
 * Only ramps with an int32 base are supported.
 */
void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
  // constraint of current logic
  ICHECK_EQ(op->base.dtype(), DataType::Int(32));
  os << "((int" << op->lanes << ")(";
  for (int lane = 0; lane < op->lanes; lane++) {
    // Comma goes between consecutive elements.
    if (lane != 0) os << ", ";
    os << "(" << PrintExpr(op->base) << ")";
    os << "+(" << PrintExpr(op->stride) << "*" << lane << ")";
  }
  os << "))";
}
/*! \brief Shuffle is not supported by this generator; always a fatal error. */
void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
  LOG(FATAL) << "Shuffle: not supported ";
}
/*! \brief Broadcast is not supported by this generator; always a fatal error. */
void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
  LOG(FATAL) << "Broadcast: not supported ";
}
void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
os << "(";
PrintExpr(op->condition, os);
os << " ? ";
PrintExpr(op->true_value, os);
os << " : ";
PrintExpr(op->false_value, os);
os << ")";
}
void CodeGenC::VisitStmt_(const LetStmtNode* op) {
std::string value = PrintExpr(op->value);
if (print_ssa_form_) {
ICHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) {
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "* " << AllocVarID(op->var.get()) << " = (";
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "*)" << value << ";\n";
} else {
PrintType(op->var.dtype(), this->stream);
this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n";
}
}
PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const AllocateNode* op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
auto scope = GetPtrStorageScope(op->buffer_var);
alloc_storage_scope_[op->buffer_var.get()] = scope;
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
stream << ' ' << vid << '[' << constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) {
BindThreadIndex(iv);
}
}
} else if (op->attr_key == tir::attr::volatile_scope) {
const VarNode* v = op->node.as<VarNode>();
ICHECK(v);
volatile_buf_.insert(v);
} else if (op->attr_key == tir::attr::pragma_import_c) {
const StringImmNode* value = op->value.as<StringImmNode>();
ICHECK(value != nullptr);
decl_stream << value->value;
}
this->PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
if (const auto* str = op->message.as<StringImmNode>()) {
// GLOG style check
stream << "ICHECK(" << cond << ") << \"" << str->value << "\";\n";
} else {
stream << "assert(" << cond << ");\n";
}
this->PrintStmt(op->body);
}
void CodeGenC::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
}
// Emit a `while (<condition>) { ... }` loop; the body is printed inside its
// own scope.
void CodeGenC::VisitStmt_(const WhileNode* op) {
PrintIndent();
// NOTE(review): the condition is rendered as part of the same stream
// expression that writes "while (". If PrintExpr ever emits helper statements
// into `stream` (presumably possible in SSA mode — confirm), they would land
// inside the loop header.
stream << "while (" << PrintExpr(op->condition) << ") {\n";
int while_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(while_scope);
PrintIndent();
stream << "}\n";
}
void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
stream << "if " << cond << " {\n";
} else {
stream << "if (" << cond << ") {\n";
}
int then_scope = BeginScope();
PrintStmt(op->then_case);
this->EndScope(then_scope);
if (op->else_case.defined()) {
PrintIndent();
stream << "} else {\n";
int else_scope = BeginScope();
PrintStmt(op->else_case);
this->EndScope(else_scope);
}
PrintIndent();
stream << "}\n";
}
void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
PrintStmt(stmt);
}
}
void CodeGenC::VisitStmt_(const EvaluateNode* op) {
if (is_const_int(op->value)) return;
const CallNode* call = op->value.as<CallNode>();
if (call) {
if (call->op.same_as(builtin::tvm_storage_sync())) {
this->PrintStorageSync(call);
return;
} else if (call->op.same_as(builtin::tvm_struct_set())) {
ICHECK_EQ(call->args.size(), 4);
int kind = call->args[2].as<IntImmNode>()->value;
std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], kind);
std::string value = PrintExpr(call->args[3]);
std::string cast;
if (kind == builtin::kArrStrides) {
// cast void* to int64_t*
cast = call->args[3]->dtype.is_handle() ? "(int64_t*)" : "";
} else if (kind == builtin::kArrDeviceType) {
// cast int to enum
cast = "(DLDeviceType)";
}
this->PrintIndent();
this->stream << ref << " = " << cast << value << ";\n";
return;
}
}
std::string vid = this->PrintExpr(op->value);
if (vid != "") {
this->PrintIndent();
this->stream << vid << ";\n";
}
}
/*!
 * \brief Append lane `i` to a vector value that is assembled lane by lane.
 *
 * 8-bit integer vectors are built by OR-ing together masked, shifted bytes.
 * Every other type is built as a cast-style literal `((<type>)(v0,v1,...))`:
 * lane 0 opens it and the last lane closes it.
 * \param t The vector type; it must have more than one lane.
 * \param i Lane number of `value`.
 * \param value Text of the scalar for this lane.
 * \param os Stream that accumulates the vector expression.
 */
void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) {
  ICHECK_GT(t.lanes(), 1);
  const bool packed_bytes = t.bits() == 8 && (t.is_int() || t.is_uint());
  if (packed_bytes) {
    if (i != 0) {
      os << "|";
    }
    const int shift = i * 8;
    os << "((0x000000ff << " << shift << ") & (" << value << " << " << shift << "))";
    return;
  }
  const bool first_lane = (i == 0);
  const bool last_lane = (i == t.lanes() - 1);
  if (first_lane) {
    os << "((";
    PrintType(t, os);
    os << ")(";
  }
  os << value;
  os << (last_lane ? "))" : ",");
}
/*!
 * \brief Print the restrict qualifier, preceded by a space, when one is configured.
 * \param v The parameter being qualified (unused here).
 * \param os Destination stream.
 */
void CodeGenC::PrintRestrict(const Var& v, std::ostream& os) {
  if (restrict_keyword_.empty()) {
    return;
  }
  os << ' ' << restrict_keyword_;
}
/*!
 * \brief Tell whether `s` is enclosed in one matching pair of parentheses.
 *
 * True only if `s` starts with '(' and ends with ')' and the opening
 * parenthesis is closed by the final character, e.g. "(a + b)" but not
 * "(a) + (b)".
 * \param s The text to inspect.
 * \return Whether the first and last characters form a matching pair.
 */
static bool CheckOutermostBracketMatch(const std::string& s) {
  if (s.empty() || s.front() != '(' || s.back() != ')') {
    return false;
  }
  const size_t last = s.size() - 1;
  int depth = 0;
  for (size_t pos = 0; pos <= last; ++pos) {
    switch (s[pos]) {
      case '(':
        ++depth;
        break;
      case ')':
        --depth;
        break;
      default:
        break;
    }
    // The first time the nesting depth returns to zero marks the partner of
    // the opening parenthesis; it must be the final character.
    if (depth == 0) {
      return pos == last;
    }
  }
  return false;
}
} // namespace codegen
} // namespace tvm