blob: 8c1c83915cd19861e13c4a60624056b2bf6c1873 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/build_module.h
* \brief Functions for compiling ops.
*/
#ifndef TVM_BUILD_MODULE_H_
#define TVM_BUILD_MODULE_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "runtime/packed_func.h"
#include "schedule_pass.h"
#include "lowered_func.h"
namespace tvm {
/*!
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*
* Every public field except the private string cache is exposed to reflection
* through VisitAttrs, in declaration order.
*/
class TargetNode : public Node {
public:
/*! \brief The name of the target */
std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*!
* \brief The type of the target device
* NOTE(review): unlike max_num_threads and thread_warp_size this field has no
* default member initializer - confirm that every construction path assigns it.
*/
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<Expr> keys_array;
/*! \brief Options for this target */
Array<Expr> options_array;
/*! \brief Collection of imported libs */
Array<Expr> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
/*!
* \brief Report the reflected fields to an attribute visitor.
* \param v The visitor that receives each (name, field pointer) pair.
*/
void VisitAttrs(AttrVisitor* v) final {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
}
/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> keys() const;
/*! \brief Get the options for this target as a vector of string */
TVM_DLL std::vector<std::string> options() const;
/*! \brief Get the libs for this target as an unordered_set of string */
TVM_DLL std::unordered_set<std::string> libs() const;
/*! \brief Type key under which this node type is registered */
static constexpr const char* _type_key = "Target";
TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
private:
/*!
* \brief Internal string repr.
* Declared mutable so that const members can write to it (presumably a cache
* filled by str() - confirm against the implementation).
*/
mutable std::string str_repr_;
};
/*!
 * \brief Reference handle to a TargetNode. Also hosts the static helpers that
 *  manage the thread local stack of target contexts.
 */
class Target : public NodeRef {
 public:
  /*! \brief The node type this reference points to */
  using ContainerType = TargetNode;

  /*! \brief Default constructor; default-constructs the NodeRef base */
  Target() {}
  /*!
   * \brief Wrap an existing node pointer.
   * \param n The node pointer handed to the NodeRef base.
   */
  explicit Target(NodePtr<Node> n) : NodeRef(n) {}

  /*!
   * \brief Parse a string into a Target.
   * \param target_str The string to parse.
   * \return The Target created from the string.
   */
  TVM_DLL static Target create(const std::string& target_str);

  /*!
   * \brief Push a target onto the thread local context stack. The Target on top
   *  of the stack decides which specialization a GenericFunc invocation uses.
   * \param target The target that becomes the current context.
   */
  TVM_DLL static void EnterTargetScope(const tvm::Target& target);

  /*!
   * \brief Pop the top target off the thread local context stack, so that the
   *  previous target becomes the current context again.
   */
  TVM_DLL static void ExitTargetScope();

  /*!
   * \brief Fetch the current target context from thread local storage.
   * \param allow_not_defined When true, an empty context stack yields an
   *  undefined Target. When false, an empty context stack is a runtime error.
   * \return The current target context; it may be undefined if
   *  allow_not_defined is true.
   */
  TVM_DLL static tvm::Target current_target(bool allow_not_defined = true);

  /*!
   * \brief Read-only access to the underlying node.
   * \return The held node pointer, statically cast to const TargetNode*.
   */
  const TargetNode* operator->() const {
    const auto* raw = node_.get();
    return static_cast<const TargetNode*>(raw);
  }
};
/*!
 * \brief RAII container to provide a scoped target context. Pushes a target onto the
 *  context stack when constructed, and pops it when destructed.
 *
 *  The type is non-copyable: every instance pops exactly once in its destructor,
 *  so a copy would pop the context stack a second time for a single push.
 */
struct TargetContext {
  /*!
   * \brief Enter a new target context. The given target becomes the new current context.
   *  When the TargetContext is destructed, the previous context is restored.
   * \param target The target to set as the new current context.
   */
  explicit TargetContext(const tvm::Target& target) {
    Target::EnterTargetScope(target);
  }
  /*! \brief Destructor. Pops the context off the thread local stack. */
  ~TargetContext() {
    Target::ExitTargetScope();
  }
  /*! \brief Copying is disabled to keep enter/exit calls balanced. */
  TargetContext(const TargetContext&) = delete;
  /*! \brief Copy assignment is disabled to keep enter/exit calls balanced. */
  TargetContext& operator=(const TargetContext&) = delete;
};
/*!
* \brief This namespace provides functions to construct Target instances.
*
* Every factory takes an optional list of option strings; when the argument is
* omitted an empty vector is passed.
*/
namespace target {
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for LLVM
*/
TVM_DLL Target llvm(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for CUDA
*/
TVM_DLL Target cuda(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for ROCm
*/
TVM_DLL Target rocm(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for OpenCL
*/
TVM_DLL Target opencl(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for Metal
*/
TVM_DLL Target metal(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for rasp
*/
TVM_DLL Target rasp(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for Mali
*/
TVM_DLL Target mali(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for Intel Graphics
*/
TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
std::vector<std::string>());
/*!
* \param options Option strings for the target (defaults to none).
* \return A target for stackvm
*/
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target
// Forward declaration of the reference type; its full definition follows BuildConfigNode.
class BuildConfig;
/*!
* \brief Container for build configuration options
*
* All fields carry default member initializers. Every field except
* add_lower_pass is exposed to reflection through VisitAttrs.
*/
class BuildConfigNode : public Node {
public:
/*!
* \brief The data alignment to use when constructing buffers. If this is set to
* -1, then TVM's internal default will be used
*/
int data_alignment = -1;
/*!
* \brief The offset factor to use when constructing buffers. If this is set to
* 0, then the offset field is not used.
*/
int offset_factor = 0;
/*!
* \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
* done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
*/
int double_buffer_split_loop = 1;
/*! \brief Threshold of number of steps in the loop to be automatically unrolled */
int auto_unroll_max_step = 0;
/*! \brief The maximum nested level of loops that can be automatically unrolled */
int auto_unroll_max_depth = 8;
/*! \brief The maximum extent of loop that will be unrolled */
int auto_unroll_max_extent = 0;
/*!
* \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
* be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
*/
bool unroll_explicit = true;
/*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
bool restricted_func = true;
/*! \brief Whether to detect global barrier */
bool detect_global_barrier = false;
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
/*!
* \brief Additional lowering passes, each stored as an (int, PackedFunc) pair.
* The meaning of the int is not visible in this header (presumably the lowering
* phase at which the pass is inserted - confirm). Not visited by VisitAttrs.
*/
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
/*! \brief Whether to instrument loads and stores with check for out of the bounds. */
bool instrument_bound_checkers = false;
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;
/*!
* \brief Report the reflected fields to an attribute visitor.
* \param v The visitor that receives each (name, field pointer) pair.
*/
void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
v->Visit("unroll_explicit", &unroll_explicit);
v->Visit("restricted_func", &restricted_func);
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
}
/*! \brief Type key under which this node type is registered */
static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
};
/*!
* \brief Container for build configuration options
*/
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
TVM_DLL static void ExitBuildConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
TVM_DLL static tvm::BuildConfig Current();
using ContainerType = BuildConfigNode;
};
/*!
 * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
 *  context stack when constructed, and pops it when destructed.
 *
 *  The type is non-copyable: every instance pops exactly once in its destructor,
 *  so a copy would pop the context stack a second time for a single push.
 */
struct BuildConfigContext {
  /*!
   * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
   *  context. When the BuildConfigContext is destructed, the previous context is restored.
   * \param build_config The BuildConfig to set as the new current context.
   */
  explicit BuildConfigContext(const tvm::BuildConfig& build_config) {
    BuildConfig::EnterBuildConfigScope(build_config);
  }
  /*! \brief Destructor. Pops the context off the thread local stack. */
  ~BuildConfigContext() {
    BuildConfig::ExitBuildConfigScope();
  }
  /*! \brief Copying is disabled to keep enter/exit calls balanced. */
  BuildConfigContext(const BuildConfigContext&) = delete;
  /*! \brief Copy assignment is disabled to keep enter/exit calls balanced. */
  BuildConfigContext& operator=(const BuildConfigContext&) = delete;
};
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* (presumably with every option at the default declared in BuildConfigNode -
* confirm against the implementation).
* \return The new BuildConfig
*/
TVM_DLL BuildConfig build_config();
/*!
* \brief Lower a schedule, with the given args and binds, into an array of LoweredFunc
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments, as a map from Tensor to the Buffer bound to it.
* \param config The build configuration.
* \return The lowered functions.
*/
TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const Array<Tensor>& args,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The built module, as a runtime::Module.
*/
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
// Forward declaration; GenericFuncNode is defined after GenericFunc in this header.
class GenericFuncNode;
/*!
* \brief Generic function that can be specialized on a per-target basis.
*/
class GenericFunc : public NodeRef {
public:
/*! \brief Default constructor; default-constructs the NodeRef base */
GenericFunc() {}
/*!
* \brief Wrap an existing node pointer.
* \param n The node pointer handed to the NodeRef base.
*/
explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief Set the default function implementation.
* \param value The default function
* \param allow_override If true, this call may override a previously registered function. If
* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
bool allow_override = false);
/*!
* \brief Register a specialized function
* \param tags The tags for this specialization
* \param value The specialized function
* \param allow_override If true, this call may override previously registered tags. If false,
* an error will be logged if the call would override previously registered tags.
* \return reference to self.
*/
TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
const runtime::PackedFunc value,
bool allow_override = false);
/*!
* \brief Call generic function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args Types of the arguments to be passed.
* \return The value returned by the invoked function.
*
* \code
* // Example code on how to call generic function
* void CallGeneric(GenericFunc f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template<typename... Args>
inline runtime::TVMRetValue operator()(Args&& ...args) const;
/*!
* \brief Invoke the relevant function for the current target context, i.e. the
* target on top of the stack managed by Target::EnterTargetScope / ExitTargetScope.
* Arguments are passed in packed format.
* \param args The arguments to pass to the function.
* \param ret The return value
*/
TVM_DLL void CallPacked(runtime::TVMArgs args,
runtime::TVMRetValue* ret) const;
/*!
* \brief Find or register the GenericFunc instance corresponding to the given name
* \param name The name of the registered GenericFunc
* \return The GenericFunc instance
*/
TVM_DLL static GenericFunc Get(const std::string& name);
/*!
* \brief Add a GenericFunc instance to the registry
* \param func The GenericFunc instance
* \param name The name of the registered GenericFunc
*/
TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline GenericFuncNode* operator->();
// declare container type
using ContainerType = GenericFuncNode;
// Internal class, only declared here; its definition is not part of this header.
struct Manager;
private:
friend struct Manager;
};
/*!
 * \brief Pack the given arguments and forward them to CallPacked.
 * \param args Arguments to be passed; each is stored into a TVMValue slot with a
 *  matching type code by runtime::TVMArgsSetter.
 * \tparam Args Types of the arguments.
 * \return The value that CallPacked wrote into the return slot.
 */
template<typename... Args>
inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
  const int kNumArgs = sizeof...(Args);
  // A zero-length array is not valid C++, so keep at least one slot even when
  // the call has no arguments; kNumArgs is still what is reported to the callee.
  const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
  TVMValue values[kArraySize];
  int type_codes[kArraySize];
  // TVMArgsSetter and TVMArgs are qualified with runtime:: like the other
  // runtime names in this function, so lookup does not depend on a
  // using-declaration from some other header.
  runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
                            std::forward<Args>(args)...);
  runtime::TVMRetValue rv;
  CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
  return rv;
}
/*!
* \brief Represents a generic function that can be specialized on a per-target basis.
*/
class GenericFuncNode : public Node {
public:
/*! \brief name of the function */
std::string name_;
/*!
* \brief the generic builder
* (presumably the default implementation installed by GenericFunc::set_default - confirm)
*/
runtime::PackedFunc generic_func_;
/*!
* \brief map from keys to registered functions
* (presumably filled by GenericFunc::register_func, one entry per tag - confirm)
*/
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
/*! \brief Type key under which this node type is registered */
static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
};
/*!
 * \brief Mutable access to the internal node container.
 * \return The held node pointer, statically cast to GenericFuncNode*.
 */
inline GenericFuncNode* GenericFunc::operator->() {
  auto* container = static_cast<GenericFuncNode*>(node_.get());
  return container;
}
/*!
* \def TVM_GENERIC_FUNC_REG_VAR_DEF
* \brief Declaration prefix for the static GenericFunc reference created by
* TVM_REGISTER_GENERIC_FUNC. The ## operator pastes __mk_ and TVM into the
* identifier __mk_TVM; TVM_REGISTER_GENERIC_FUNC then passes it to TVM_STR_CONCAT
* together with __COUNTER__ (TVM_STR_CONCAT and TVM_ATTRIBUTE_UNUSED are defined
* outside this header).
*/
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM
/*!
* \def TVM_REGISTER_GENERIC_FUNC
* \brief Register a new generic function, or set a device-specific variant
* of the corresponding function.
*
* Expands to a static GenericFunc reference initialized from
* ::tvm::GenericFunc::Get(#name), where #name is the stringified argument.
*
* NOTE(review): the variable is a non-const reference while GenericFunc::Get
* returns by value, so the expansion only forms a valid initializer when it is
* followed by a chained call that returns GenericFunc&, such as
* .set_default(...) or .register_func(...).
*
* \param name The name of the function
*/
#define TVM_REGISTER_GENERIC_FUNC(name) \
TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::GenericFunc::Get(#name)
} // namespace tvm
#endif // TVM_BUILD_MODULE_H_