// blob: 5cb59fd477121d02a295d9cdd86b99dcaa0d0d7c [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/lowered_func.h
* \brief Information about a lowered TVM function.
* This data structure is the final step toward codegen.
*/
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
#include <ir/FunctionBase.h>
#include <string>
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "tvm/node/container.h"
namespace tvm {
// Forward declaration of the node container that backs LoweredFunc.
class LoweredFuncNode;

/*!
 * \brief Reference to a function in its lowered form.
 *  A lowered function is the last IR representation prior to codegen.
 */
class LoweredFunc : public FunctionRef {
 public:
  /*! \brief The node container type this reference points to */
  using ContainerType = LoweredFuncNode;

  /*! \brief Default constructor; no node is passed to the base class */
  LoweredFunc() {}

  /*!
   * \brief Construct a reference from a node pointer.
   * \param n The node pointer handed to the FunctionRef base.
   */
  explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}

  /*!
   * \brief Access the internal node container.
   * \return The pointer to the internal node container.
   */
  inline const LoweredFuncNode* operator->() const;
};
/*!
 * \brief Kind of a lowered function, stored with an int underlying type.
 */
enum LoweredFuncType : int {
  /*! \brief The function may mix host calls and device calls */
  kMixedFunc = 0,
  /*! \brief The function contains host code only */
  kHostFunc = 1,
  /*! \brief The function contains device code only */
  kDeviceFunc = 2
};
/*!
* \brief Node container of LoweredFunc.
* Stores the name, arguments, thread axes, handle type hints,
* flags and body statement of a lowered function.
*/
class LoweredFuncNode : public FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
* This function can only take POD types (int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
* Each axis needs the host function to specify a size.
* \note Calling convention into LoweredFunc
*
* Assume we have a LoweredFunc f, a call into f
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
* The CodeGen should take this and translate this call
* to the corresponding API-specific kernel launches or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
* Can be used as a hint when generating the type signature.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
* \note Expr is used instead of Type, because Type cannot be held by Map.
* A constant Expr of the given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief The type of the function; defaults to kMixedFunc */
LoweredFuncType func_type{kMixedFunc};
/*! \brief Whether this function is a packed function; defaults to true */
bool is_packed_func{true};
/*!
* \brief Whether the function ensures that argument pointers do not alias.
* This corresponds to the restrict keyword in C. Defaults to true.
*/
bool is_restricted{true};
/*! \brief The body statement of the function */
Stmt body;
/*! \return name of the function (the `name` field) */
const std::string& func_name() const final {
return name;
}
// There is no return value, but 1 is returned
// to enable Call into this function.
int num_outputs() const final {
return 1;
}
/*!
* \brief Visit every field of this node, each under a key equal to its member name.
* \param v The attribute visitor that receives each key and field pointer.
*/
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("func_type", &func_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("is_restricted", &is_restricted);
v->Visit("body", &body);
}
/*! \brief The type key string of this node type */
static constexpr const char* _type_key = "LoweredFunc";
// NOTE(review): the macro is given Node as the parent although the class
// derives from FunctionBaseNode - confirm this is intended.
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
};
// Out-of-class definitions of the inline members declared above.
inline const LoweredFuncNode* LoweredFunc::operator->() const {
  // Cast the stored node pointer to the container type of this reference.
  const auto* container = node_.get();
  return static_cast<const LoweredFuncNode*>(container);
}
} // namespace tvm
namespace std {
/*!
 * \brief std::hash specialization for LoweredFunc.
 *  The hash value is whatever the reference's own hash() member returns.
 */
template <>
struct hash<::tvm::LoweredFunc> {
  std::size_t operator()(const ::tvm::LoweredFunc& func) const {
    const std::size_t digest = func.hash();
    return digest;
  }
};
}  // namespace std
#endif // TVM_LOWERED_FUNC_H_