// blob: c9af24a04a3cc5de321917d6da9f910e9a0a993d [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file codegen_c.h
* \brief Common utilities to generate C-style code.
*/
#ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <sstream>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "codegen_source_base.h"
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class to generate C code.
*
* CodeGenC has two modes: generate SSA-formed C code or normal form.
*
* **NOTE** CodeGenC does not aim at generating C code to be consumed by MSVC or GCC.
* Rather, it provides an infrastructural abstraction for C variants like CUDA
* and OpenCL-C. You might find some odd variant features, e.g., type `int3` for
* a vector of 3 `int`s. For a native C code generator, see `CodeGenLLVM`.
*/
class CodeGenC :
      public ExprFunctor<void(const Expr&, std::ostream&)>,
      public StmtFunctor<void(const Stmt&)>,
      public CodeGenSourceBase {
 public:
  /*!
   * \brief Initialize the code generator.
   * \param output_ssa Whether to output SSA.
   */
  void Init(bool output_ssa);
  /*!
   * \brief Add the function to the generated module.
   * \param f The function to be compiled.
   */
  void AddFunction(LoweredFunc f);
  /*!
   * \brief Finalize the compilation and return the code.
   * \return The code.
   */
  std::string Finish();
  /*!
   * \brief Print the Stmt n to CodeGenC->stream.
   *
   * Forwards the statement to VisitStmt.
   *
   * \param n The statement to be printed.
   */
  void PrintStmt(const Stmt& n) {
    VisitStmt(n);
  }
  /*!
   * \brief Print the expression n (or its SSA id if in SSA mode) into os.
   * \param n The expression to be printed.
   * \param os The output stream.
   */
  void PrintExpr(const Expr& n, std::ostream& os);
  /*!
   * \brief Same as the stream overload of PrintExpr, but returns the result
   *  as a string instead of writing it to a caller-supplied stream.
   * \param n The expression to be printed.
   * \return The text that the stream overload wrote for n.
   */
  std::string PrintExpr(const Expr& n) {
    // Requires <sstream>; collect the stream overload's output in a buffer.
    std::ostringstream os;
    PrintExpr(n, os);
    return os.str();
  }
  // The following parts are overloadable print operations.
  /*!
   * \brief Insert statement before function body.
   *
   * The default implementation does nothing.
   *
   * \param f The function to be compiled.
   */
  virtual void PreFunctionBody(LoweredFunc f) {}
  /*!
   * \brief Initialize codegen state for generating f.
   * \param f The function to be compiled.
   */
  virtual void InitFuncState(LoweredFunc f);
  // Expression visitors: each one prints the expression op into os.
  void VisitExpr_(const Variable* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Load* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Let* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Call* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Add* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Sub* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const NE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LT* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const LE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GT* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const GE* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const And* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Or* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Cast* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const FloatImm* op, std::ostream& os) override;  // NOLINT(*)
  void VisitExpr_(const StringImm* op, std::ostream& os) override;  // NOLINT(*)
  // Statement visitors.
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
  /*!
   * \brief Print the type representation of type t.
   * \param t The type representation.
   * \param os The stream to print the ctype into.
   */
  virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
  /*!
   * \brief Print expr representing the thread tag.
   * \param iv The thread index to be bound.
   */
  virtual void BindThreadIndex(const IterVar& iv);  // NOLINT(*)
  /*!
   * \brief Print the storage scope into os.
   *
   * NOTE(review): the implementation is not visible in this header;
   * presumably this emits the target-specific qualifier for scope -- confirm.
   *
   * \param scope The storage scope.
   * \param os The output stream.
   */
  virtual void PrintStorageScope(const std::string& scope, std::ostream& os);  // NOLINT(*)
  /*!
   * \brief Print a storage synchronization for the given call.
   *
   * NOTE(review): the implementation is not visible in this header; confirm
   * the exact contract against the definition.
   *
   * \param op The call node.
   */
  virtual void PrintStorageSync(const Call* op);  // NOLINT(*)
  /*!
   * \brief Print a binary vector operation.
   * \param op The operator string.
   * \param op_type The type of the operation.
   * \param lhs The left operand.
   * \param rhs The right operand.
   * \param os The output stream.
   */
  virtual void PrintVecBinaryOp(
      const std::string&op, Type op_type,
      Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
  /*!
   * \brief Get the string for a vector load.
   * \param t The type to load.
   * \param buffer The buffer variable.
   * \param base The base expression of the access.
   * \return The vector load as a string.
   */
  virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base);
  /*!
   * \brief Print a vector store.
   * \param buffer The buffer variable.
   * \param t The type to store.
   * \param base The base expression of the access.
   * \param value The value to be stored, already printed as a string.
   */
  virtual void PrintVecStore(const Variable* buffer,
                             Type t, Expr base,
                             const std::string& value);  // NOLINT(*)
  /*!
   * \brief Print the load of a single element of a vector.
   * \param vec The vector, as a string.
   * \param t The type.
   * \param i The element index.
   * \param os The output stream.
   */
  virtual void PrintVecElemLoad(
      const std::string& vec, Type t, int i, std::ostream& os);  // NOLINT(*)
  /*!
   * \brief Print the store of a single element of a vector.
   * \param vec The vector, as a string.
   * \param t The type.
   * \param i The element index.
   * \param value The value to be stored, as a string.
   */
  virtual void PrintVecElemStore(
      const std::string& vec, Type t, int i, const std::string& value);
  /*!
   * \brief Get a cast of value from type `from` to type `target`.
   *
   * NOTE(review): the exact text produced is decided by the implementation,
   * which is not visible in this header.
   *
   * \param value The value to cast, as a string.
   * \param from The source type.
   * \param target The destination type.
   * \return The cast, as a string.
   */
  virtual std::string CastFromTo(std::string value, Type from, Type target);

 protected:
  /*!
   * \brief Get a reference to a struct location, as a string.
   *
   * NOTE(review): the meaning of `kind` is defined by the implementation,
   * which is not visible in this header.
   */
  std::string GetStructRef(
      Type t, const Expr& buffer, const Expr& index, int kind);
  /*!
   * \brief Get a reference to a buffer as type t at the given index.
   * \param t The type of the access.
   * \param buffer The buffer variable.
   * \param index The index expression.
   * \return The buffer reference, as a string.
   */
  virtual std::string GetBufferRef(
      Type t, const Variable* buffer, Expr index);
  /*!
   * \brief If buffer is allocated as type t.
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
  bool HandleTypeMatch(const Variable* buf_var, Type t) const;
  /*!
   * \brief Register the data type of buf_var.
   * \param buf_var The buffer variable.
   * \param t The type to be registered for buf_var.
   */
  void RegisterHandleType(const Variable* buf_var, Type t);
  /*!
   * \brief Override of a base-class virtual; marked final, so classes
   *  deriving from CodeGenC cannot override it again.
   */
  void PrintSSAAssign(
      const std::string& target, const std::string& src, Type t) final;
  /*! \brief restrict keyword */
  std::string restrict_keyword_{""};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
  /*! \brief the data type of allocated buffers */
  std::unordered_map<const Variable*, Type> handle_data_type_;
  /*! \brief reserves common C keywords */
  void ReserveKeywordsAsUnique();

 private:
  /*! \brief whether to print in SSA form */
  bool print_ssa_form_{false};
  /*! \brief set of volatile buf access */
  std::unordered_set<const Variable*> volatile_buf_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_C_H_