/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file relay/backend/compile_engine.cc
* \brief Internal compilation engine.
*/
#include "compile_engine.h"
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/tags.h>
#include <functional>
#include <limits>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../transforms/pass_utils.h"
#include "utils.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);
TVM_REGISTER_OBJECT_TYPE(CompileEngineNode);
LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
auto n = make_object<LoweredOutputNode>();
n->outputs = std::move(outputs);
n->implementation = std::move(impl);
data_ = std::move(n);
}
CCacheKey::CCacheKey(Function source_func, Target target) {
auto n = make_object<CCacheKeyNode>();
n->source_func = std::move(source_func);
n->target = std::move(target);
data_ = std::move(n);
}
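/*!
* \brief Convert a shape from type inference into a lowering-friendly shape,
* narrowing constant dimensions to int32 where possible.
* \param shape The shape returned by type inference.
* \return The converted shape.
*/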
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// For now, we always use int32 shapes when possible,
// even if the result of shape inference becomes int64.
Array<IndexExpr> res;
for (IndexExpr val : shape) {
const int64_t* pval = tir::as_const_int(val);
if (pval != nullptr) {
#ifndef TVM_INDEX_DEFAULT_I64
ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
res.push_back(IntImm(DataType::Int(32), *pval));
#else
res.push_back(val);
#endif // TVM_INDEX_DEFAULT_I64
} else if (val->IsInstance<tir::AnyNode>()) {
res.push_back(val.as<tir::AnyNode>()->ToVar());
} else {
res.push_back(val);
}
}
return res;
}
// The functor that builds the TE compute expressions and schedule
// for a fused primitive function.
class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
}
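/*!
* \brief Create the cached function for a fused primitive function.
* \param prim_func The primitive function to lower.
* \return The CachedFunc holding the input/output tensors, function name and schedule.
*/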
CachedFunc Create(const Function& prim_func) {
auto cache_node = make_object<CachedFuncNode>();
cache_node->target = target_;
for (Var param : prim_func->params) {
Array<tvm::te::Tensor> inputs;
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
} else {
// flatten tuple of tensor type.
const auto* tuple_type = param->type_as<TupleTypeNode>();
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
// TODO(@icemelon): Allow recursive tuple
ICHECK(ttype != nullptr);
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
}
}
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
cache_node->outputs = this->VisitExpr(prim_func->body);
auto candidate_name = readable_name_stream_.str();
constexpr static size_t kMaxFuncNameLength = 80;
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}
cache_node->func_name = candidate_name;
ICHECK(anchor_op_.defined());
// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non-PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : cache_node->outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}
te::Schedule schedule;
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
// Use the TOPI schedule if the user specified it, or if the function has no auto_scheduler schedule.
if (!schedule.defined()) {
ICHECK(anchor_implementation_.defined());
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
}
for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
cache_node->schedule = std::move(schedule);
return CachedFunc(cache_node);
}
Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
}
Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
auto value = te::compute(
{},
[&](const Array<tvm::tir::Var>&) {
if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == DataType::Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == DataType::Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == DataType::Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::PrimExpr();
}
},
"compile_engine_const", topi::kBroadcast);
scalars_.push_back(value->op);
return {value};
}
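// Lower a call to a primitive op via relay.backend.lower_call and track the op with the
// highest TOpPattern as the anchor op, whose implementation supplies the fused schedule.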
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
Array<te::Tensor> inputs;
int count_tuple = 0;
for (Expr arg : call_node->args) {
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
for (te::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
}
if (count_tuple) {
ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
}
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Array<te::Tensor> outputs;
OpImplementation impl;
// Skip fcompute for device copy operators as it is not registered.
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
} else {
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
impl = lowered_out->implementation;
}
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern > anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
anchor_implementation_ = impl;
}
if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
ICHECK(tuple_type) << "Expect output to be a tuple type";
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
}
// Set the name to `__copy`. It will be detected by the graph executor to perform
// data copy across devices.
if (op == device_copy_op_) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
readable_name_stream_ << '_' << op->name;
}
return outputs;
}
Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
return Array<te::Tensor>();
}
Array<te::Tensor> VisitExpr_(const LetNode* op) final {
Array<te::Tensor> val = VisitExpr(op->value);
ICHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}
Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
Array<te::Tensor> fields;
for (Expr field : op->fields) {
ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
Array<te::Tensor> res = VisitExpr(field);
ICHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}
Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
Array<te::Tensor> tuple = VisitExpr(op->tuple);
ICHECK_EQ(tuple_type->fields.size(), tuple.size());
ICHECK_GE(op->index, 0);
ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
return {tuple[op->index]};
}
private:
tvm::Target target_;
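// The anchor op is the op with the highest pattern seen so far; its attributes and
// implementation are used to pick the TOPI schedule for the fused function.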
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{-1};
OpImplementation anchor_implementation_;
std::ostringstream readable_name_stream_;
Array<te::Operation> scalars_;
bool use_auto_scheduler_;
// Cache the device_copy op so the equivalence check on each CallNode visit
// avoids a registry lookup when retrieving schedules.
const Op& device_copy_op_;
};
/*!
* \brief Create schedule for target.
* \param source_func The primitive function to be lowered.
* \param target The target we want to create schedule for.
* \return The CachedFunc with the schedule populated.
* Its funcs field is not yet populated.
*/
CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
return ScheduleGetter(target).Create(source_func);
}
// The functor that creates the shape function for a primitive function.
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
MakeShapeFunc() {}
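/*!
* \brief Create the shape function for a primitive function.
* \param prim_func The primitive function whose output shapes are computed at runtime.
* \return A pair of the schedule and the cached shape function.
*/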
std::pair<te::Schedule, CachedFunc> Create(const Function& prim_func) {
for (auto param : prim_func->params) {
param_states_[param] = kNoNeed;
Array<tvm::te::Tensor> data_inputs;
Array<tvm::te::Tensor> shape_inputs;
auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
// Add data placeholder
Shape shape = GetShape(ttype->shape);
tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype);
data_inputs.push_back(data_tensor);
// Add shape placeholder
int64_t ndim = shape.size();
Shape sshape;
if (ndim > 0) {
sshape.push_back(tvm::Integer(ndim));
}
tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64));
shape_inputs.push_back(shape_tensor);
};
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
add_placeholder(ttype);
} else {
// flatten tuple of tensor type.
const auto* tuple_type = param->type_as<TupleTypeNode>();
// TODO(@icemelon): Support recursive tuple
ICHECK(tuple_type);
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
ICHECK(ttype);
add_placeholder(ttype);
}
}
param_data_[param] = data_inputs;
param_shapes_[param] = shape_inputs;
}
readable_name_stream_ << "shape_func";
auto cache_node = make_object<CachedFuncNode>();
cache_node->outputs = VisitExpr(prim_func->body);
auto candidate_name = readable_name_stream_.str();
constexpr static size_t kMaxFuncNameLength = 80;
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}
cache_node->func_name = candidate_name;
// set inputs
for (auto param : prim_func->params) {
int state = param_states_[param];
cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
if (state & kNeedInputData) {
for (auto t : param_data_[param]) {
cache_node->inputs.push_back(t);
}
}
if (state & kNeedInputShape) {
for (auto t : param_shapes_[param]) {
cache_node->inputs.push_back(t);
}
}
}
CachedFunc cfunc(cache_node);
// generate schedule for shape func
Array<te::Operation> out_ops;
for (auto t : cache_node->outputs) {
out_ops.push_back(t->op);
}
auto schedule = te::create_schedule(out_ops);
tvm::te::AutoInlineInjective(schedule);
for (const auto& scalar : scalars_) {
auto scalar_op = scalar->op;
if (schedule->Contain(scalar_op)) {
schedule[scalar_op].compute_inline();
}
}
return std::make_pair(schedule, cfunc);
}
Array<te::Tensor> VisitExpr(const Expr& expr) final {
if (expr.as<VarNode>()) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
return ExprFunctor::VisitExpr(expr);
}
// For other cases, do a memoized visit.
return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
}
Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
auto var = GetRef<Var>(var_node);
auto it = param_states_.find(var);
if (it == param_states_.end()) {
LOG(FATAL) << "Free variable " << var->name_hint();
return {};
} else {
ICHECK(data_dependents_per_input_.size());
auto data_dependent = data_dependents_per_input_.back();
if (data_dependent) {
param_states_[var] |= kNeedInputData;
return param_data_[var];
} else {
param_states_[var] |= kNeedInputShape;
return param_shapes_[var];
}
}
}
Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(data_dependents_per_input_.size());
bool data_dependent = data_dependents_per_input_.back();
if (!op->is_scalar()) {
// This is a constant weight, extract the shape of the weight tensor.
// This cannot be data dependent.
CHECK(!data_dependent);
auto ttype = op->checked_type().as<TensorTypeNode>();
int ndim = static_cast<int>(ttype->shape.size());
Array<PrimExpr> out_shape{ndim};
te::Tensor value = tvm::te::compute(
out_shape,
[&](const Array<tvm::tir::Var>& indices) {
auto idx = indices[0];
PrimExpr ret = make_const(DataType::Int(64), 0);
for (int i = 0; i < ndim; i++) {
ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
}
return ret;
},
"shape_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
}
if (data_dependent) {
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
auto value = tvm::te::compute(
{},
[&](const Array<tvm::tir::Var>&) {
if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == DataType::Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == DataType::Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == DataType::Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::PrimExpr();
}
},
"data_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
} else {
auto value = tvm::te::compute(
{}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
"shape_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
}
}
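// Invoke the op's registered FShapeFunc, marking each argument as data- or shape-dependent
// according to the op's TShapeDataDependent attribute before visiting it.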
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
<< "Error in op fusion: output of the shape func is fed to a "
<< "data-dependent shape func";
ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
ICHECK_GT(tshape_data_dependent.count(op), 0)
<< "Internal error, cannot find TShapeDataDependent for " << op->name;
Array<Integer> dep_spec = tshape_data_dependent[op];
if (dep_spec.size() == 1) {
// This is for cases when data dependence is specified per op
// Replicate 0 or 1 flag to all arguments
for (size_t i = 1; i < call_node->args.size(); ++i) {
dep_spec.push_back(dep_spec[0]);
}
}
// Visit all inputs
Array<te::Tensor> inputs;
int count_tuple = 0;
for (size_t i = 0; i < call_node->args.size(); ++i) {
Expr arg = call_node->args[i];
if (arg->checked_type().as<TupleTypeNode>()) {
++count_tuple;
}
data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
for (te::Tensor tensor : VisitExpr(arg)) {
inputs.push_back(tensor);
}
data_dependents_per_input_.pop_back();
}
if (count_tuple) {
ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
}
// Get output ndims
auto ret_type = call_node->checked_type();
Array<IndexExpr> out_ndims;
if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
} else {
auto rtype = ret_type.as<TupleTypeNode>();
// TODO(@icemelon): Allow recursive tuple
ICHECK(rtype);
for (size_t i = 0; i < rtype->fields.size(); ++i) {
auto ttype = rtype->fields[i].as<TensorTypeNode>();
ICHECK(ttype);
out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
}
}
// Call shape function
auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
readable_name_stream_ << "_" << op->name;
return outputs;
}
Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
LOG(FATAL) << "Do not support sub function";
return Array<te::Tensor>();
}
Array<te::Tensor> VisitExpr_(const LetNode* op) final {
Array<te::Tensor> val = VisitExpr(op->value);
ICHECK(!memo_.count(op->var));
memo_[op->var] = val;
return VisitExpr(op->body);
}
Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
Array<te::Tensor> fields;
for (Expr field : op->fields) {
ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
Array<te::Tensor> res = VisitExpr(field);
ICHECK_EQ(res.size(), 1);
fields.push_back(res[0]);
}
return fields;
}
Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
Array<te::Tensor> input_shapes = VisitExpr(op->tuple);
Array<te::Tensor> out;
out.push_back(input_shapes[op->index]);
return out;
}
private:
/*! \brief String stream for function name */
std::ostringstream readable_name_stream_;
/*! \brief Map from parameter to its shape function usage state */
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> param_states_;
/*! \brief Map from parameter to list of data placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
/*! \brief Stack of data dependencies for shape function, specified per each op input */
std::vector<bool> data_dependents_per_input_;
/*! \brief Scalars used in the shape function */
Array<te::Tensor> scalars_;
};
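/*!
* \brief The default compile engine implementation, caching lowered functions
* and shape functions keyed by (source function, target).
*/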
class CompileEngineImpl : public CompileEngineNode {
public:
// Lower the function.
CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
CCacheValue value = LowerInternal(key);
if (value->packed_func != nullptr) return value->packed_func;
// build the function.
tvm::runtime::Module m;
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
m = (*f)(value->cached_func->funcs, key->target);
} else {
m = build(value->cached_func->funcs, key->target, Target(nullptr));
}
value->packed_func = m.GetFunction(value->cached_func->func_name);
return value->packed_func;
}
CachedFunc LowerShapeFunc(const CCacheKey& key) final {
return LowerShapeFuncInternal(key)->cached_func;
}
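/*!
* \brief Run external codegen on all cached functions annotated with a "Compiler" attribute.
* \return The runtime modules produced by the external codegen tools.
*/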
Array<tvm::runtime::Module> LowerExternalFunctions() {
Array<tvm::runtime::Module> ret;
std::unordered_map<std::string, std::string> cached_symbol;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
ICHECK(code_gen.defined()) << "No external codegen is set";
std::string code_gen_name = code_gen.value();
cached_ext_funcs.push_back(it.first);
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
std::string sn = symbol_name.value();
if (!cached_symbol.count(sn)) {
cached_symbol[sn] = code_gen_name;
} else {
ICHECK_NE(cached_symbol[sn], code_gen_name)
<< "Found duplicated symbol: " << sn << " for: " << code_gen_name;
}
std::string ext_name = "relay.ext." + code_gen_name;
auto pf = tvm::runtime::Registry::Get(ext_name);
ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
// No need to keep compiler attribute at this point, functions have been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
runtime::Module ext_mod = (*pf)(src_func);
ICHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
}
}
// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
cache_.erase(it);
}
return ret;
}
void Clear() final { cache_.clear(); }
// List all items in the cache.
Array<ObjectRef> ListItems() {
std::lock_guard<std::mutex> lock(mutex_);
Array<ObjectRef> items;
for (auto& kv : cache_) {
items.push_back(kv.first);
items.push_back(kv.second);
}
return items;
}
// List all items in the shape_func_cache.
Array<ObjectRef> ListShapeFuncItems() {
std::lock_guard<std::mutex> lock(mutex_);
Array<ObjectRef> items;
for (auto& kv : shape_func_cache_) {
items.push_back(kv.first);
items.push_back(kv.second);
}
return items;
}
/*!
* \brief Get the cache key of the function that is being lowered currently
* \return the cache key
*/
CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
private:
// Lower a function and cache the result.
CCacheValue LowerInternal(const CCacheKey& key) {
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = cache_.find(key);
if (it != cache_.end()) {
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
} else {
value = CCacheValue(make_object<CCacheValueNode>());
value->use_count = 0;
if (!backend::IsCompileEngineCacheDisabled()) {
cache_[key] = value;
}
}
cur_ccache_key_ = key;
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node.value());
cache_node->target = Target("ext_dev");
cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func);
value->cached_func = CachedFunc(cache_node);
return value;
}
// Enforce use of the target.
With<Target> target_scope(key->target);
ICHECK(!value->cached_func.defined());
auto cfunc = CreateSchedule(key->source_func, key->target);
auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
if (const CallNode* call_node = body.as<CallNode>()) {
if (call_node->attrs.as<DeviceCopyAttrs>()) {
value->cached_func = CachedFunc(cache_node);
return value;
}
}
cache_node->func_name = GetUniqueName(cache_node->func_name);
// NOTE: array will copy on write.
Array<te::Tensor> all_args = cache_node->inputs;
for (te::Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}
// lower the function
if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
} else {
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
}
value->cached_func = CachedFunc(cache_node);
return value;
}
// Lower a shape function and cache the result.
CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = shape_func_cache_.find(key);
if (it != shape_func_cache_.end()) {
it->second->use_count += 1;
if (it->second->cached_func.defined()) return it->second;
value = it->second;
} else {
value = CCacheValue(make_object<CCacheValueNode>());
value->use_count = 0;
shape_func_cache_[key] = value;
}
// Enforce use of the target.
With<Target> target_scope(key->target);
ICHECK(!value->cached_func.defined());
auto spair = MakeShapeFunc().Create(key->source_func);
auto cache_node = make_object<CachedFuncNode>(*(spair.second.operator->()));
cache_node->func_name = GetUniqueName(cache_node->func_name);
cache_node->target = key->target;
Array<te::Tensor> all_args = cache_node->inputs;
for (te::Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds);
value->cached_func = CachedFunc(cache_node);
return value;
}
/*!
* \brief Get unique name from name.
* \param name The original name.
* \return Updated name which is unique.
*/
std::string GetUniqueName(std::string name) {
for (size_t i = 0; i < name.length(); ++i) {
if (name[i] == '.') name[i] = '_';
}
while (true) {
auto it = name_map_.find(name);
if (it == name_map_.end()) {
name_map_[name] = 1;
return name;
} else {
std::ostringstream os;
os << name << "_" << it->second;
++(it->second);
name = os.str();
}
}
return name;
}
/*! \brief compiler cache lock */
std::mutex mutex_;
/*! \brief internal name map to get a unique name */
std::unordered_map<std::string, int> name_map_;
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
/*! \brief internal compiler cache for shape funcs */
std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
/*! \brief the cache key of the function that is being lowered currently */
CCacheKey cur_ccache_key_;
};
/*! \brief The global compile engine */
CompileEngine& CompileEngine::Global() {
// Intentionally allocate a raw pointer to avoid
// freeing it during static destruction.
static CompileEngine* inst = new CompileEngine(make_object<CompileEngineImpl>());
return *inst;
}
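// PassContext options controlling whether auto_scheduler schedules are used and whether
// the compile engine cache is disabled.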
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool);
TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
return LoweredOutput(outputs, impl);
});
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed([](Function source_func, Target target) {
return CCacheKey(source_func, target);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() {
return CompileEngine::Global();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) {
self->Clear();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); });
TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) {
CompileEngineImpl* ptr = dynamic_cast<CompileEngineImpl*>(self.operator->());
ICHECK(ptr != nullptr);
return ptr->ListItems();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems")
.set_body_typed([](CompileEngine self) {
CompileEngineImpl* ptr = dynamic_cast<CompileEngineImpl*>(self.operator->());
ICHECK(ptr != nullptr);
return ptr->ListShapeFuncItems();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey")
.set_body_typed([](CompileEngine self) {
CompileEngineImpl* ptr = dynamic_cast<CompileEngineImpl*>(self.operator->());
ICHECK(ptr != nullptr);
return ptr->GetCurrentCCacheKey();
});
} // namespace relay
} // namespace tvm