blob: 791fe4a01596588adef992c91c8fa131adfd3322 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c.h
* \brief Common utilities to generate C-style code.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_
#include <tvm/ir/op.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../tir/transform/ir_utils.h"
#include "codegen_source_base.h"
namespace tvm {
namespace codegen {
using namespace tir;
/*!
* \brief A base class to generate C code.
*
* CodeGenC has two modes: generating SSA-form C code or normal-form C code.
*
* **NOTE** CodeGenC does not aim at generating C code consumed by MSVC or GCC.
* Rather, it provides an infrastructural abstraction for C variants like CUDA
* and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
* a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`.
*/
class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)>,
public CodeGenSourceBase {
public:
/*!
* \brief Initialize the code generator.
* \param output_ssa Whether to output code in SSA form.
*/
void Init(bool output_ssa);
/*!
* \brief Add the function declaration to the generated module,
* without defining it.
*
* \param gvar The GlobalVar representing the function.
* \param func The function to be compiled.
*/
virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func);
/*!
* \brief Add the function to the generated module, including its
* declaration and definition.
*
* \param gvar The GlobalVar representing the function.
* \param func The function to be compiled.
*/
virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func);
/*!
* \brief Get the name of a declared function
* \param gvar The GlobalVar of the function
* \returns The string name of the function
*/
ffi::String GetFunctionName(const GlobalVar& gvar);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
*/
virtual std::string Finish();
/*!
* \brief Print the Stmt n to CodeGenC->stream
*
* Dispatches to VisitStmt.
*
* \param n The statement to be printed.
*/
void PrintStmt(const Stmt& n) { VisitStmt(n); }
/*!
* \brief Print the expression n (or its SSA id if in SSA mode) into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr(const PrimExpr& n, std::ostream& os);
/*!
* \brief Same as PrintExpr, but simply returns the result string
* \param n The expression to be printed.
* \return Everything the stream overload of PrintExpr wrote for n.
*/
std::string PrintExpr(const PrimExpr& n) {
// Collect the output of the stream overload in a local buffer.
std::ostringstream os;
PrintExpr(n, os);
return os.str();
}
// The following parts are overloadable print operations.
/*! \brief Print the function signature before the argument list
*
* The default implementation delegates out to PrintFuncPrefix and
* PrintExtraAttrs.
*
* \param function_name The name of the function
*
* \param func The function whose signature should be printed
*
* \param os The output stream
*/
virtual void PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func,
std::ostream& os);
/*!
* \brief Print the function header before the argument list
* \param os The output stream
*
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(std::ostream& os); // NOLINT(*)
/*!
* \brief Print extra function attributes
* \param f The function whose attributes should be printed
* \param os The output stream
*
* Example: __launch_bounds__(256) for CUDA functions
*/
virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os); // NOLINT(*)
/*!
* \brief Insert statement before function body.
*
* The default implementation does nothing.
*
* \param f The function to be compiled.
*/
virtual void PreFunctionBody(const PrimFunc& f) {}
/*!
* \brief Initialize codegen state for generating f.
* \param f The function to be compiled.
*/
virtual void InitFuncState(const PrimFunc& f);
// expression
void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statement
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const DeclBufferNode* op) override;
/*!
* \brief Print expr representing the thread tag
* \param iv The thread index to be bound.
*/
virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs,
std::ostream& os); // NOLINT(*)
// print vector load
virtual std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base);
// print vector store
virtual void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os); // NOLINT(*)
// print store of single element.
virtual void PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value);
// print vector constructor
virtual void PrintVecConstructor(DataType t, std::ostream& os);
// Get a cast of value from type `from` to type `target`
virtual std::string CastFromTo(std::string value, DataType from, DataType target);
// Get load of single element with expression
virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os);
// Print restrict keyword for a given Var if applicable
virtual void PrintRestrict(const Var& v, std::ostream& os);
// Store the given alignment in constants_byte_alignment_.
virtual void SetConstantsByteAlignment(Integer constants_byte_alignment) {
constants_byte_alignment_ = constants_byte_alignment;
}
protected:
// Print reference to struct location
std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
// Print reference to a buffer as type t in index.
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index);
/*!
* \brief Handle volatile loads.
*
* This is to workaround a bug in CUDA cuda_fp16.h. Volatile accesses
* to shared memory are required for reductions. However, __half class
* does not implement volatile member functions. CUDA codegen will cast
* away volatile qualifier from CUDA __half types.
*
* \param value The string of the loaded value.
* \param op The load node the value belongs to (unused by the default implementation).
* \param os The output stream.
*/
virtual void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) {
// By default, do nothing but print the loaded value.
os << value;
}
/*!
* \brief Check if scope is part of type in the target language.
*
* **NOTE** In OpenCL, __local is part of type, so "__local int *"
* is legal. This is not the case for CUDA, where "__shared__"
* or "__constant__" is not part of type but a storage class (like
* C/C++ static).
*
* \return true in the default implementation.
*/
virtual bool IsScopePartOfType() const { return true; }
/*!
* \brief Generate forward function declarations.
*
* The default implementation does nothing.
*
* \param global_symbol The symbol of the target function.
* \param arg_types The argument types to the function.
* \param ret_type The return type of the function
*/
virtual void GenerateForwardFunctionDeclarations(ffi::String global_symbol,
const ffi::Array<Type>& arg_types,
const Type& ret_type) {}
/*!
* \brief Print external function call.
* \param ret_type The return type.
* \param global_symbol The symbol of the target function.
* \param args The arguments to the function.
* \param skip_first_arg Whether to skip the first argument.
* \param os The output stream.
*/
virtual void PrintCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr>& args, bool skip_first_arg,
std::ostream& os); // NOLINT(*)
/*!
* \brief Check whether buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The data type to register for buf_var.
*/
void RegisterHandleType(const VarNode* buf_var, DataType t);
// override
void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();
/*! \brief Check if buf_var is volatile or not, i.e. whether it is recorded in volatile_buf_. */
bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; }
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief Record of ops that have pre-defined global symbol. */
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
// cache commonly used ops
const Op& builtin_call_extern_ = builtin::call_extern();
const Op& builtin_call_pure_extern_ = builtin::call_pure_extern();
// byte alignment of constants; defaults to 16, changed via SetConstantsByteAlignment
Integer constants_byte_alignment_ = 16;
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
/*! \brief whether the module has a main function declared */
bool has_tvm_ffi_main_func_{false};
private:
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*> let_binding_;
/*! \brief Map of GlobalVar to their symbol.
*
* For externally-exposed functions, this is given by the
* tvm::attr::kTarget attribute of the PrimFunc. For internal
* functions, this is the name of the function's GlobalVar, possibly
* altered to prevent duplicate names.
*/
std::unordered_map<GlobalVar, ffi::String> internal_functions_;
/*! \brief Name supply to generate unique function names */
NameSupply func_name_supply_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_C_H_