blob: 2c78dbca257c858237b9e31076c09aa8dd9cf17b [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relax/transform/realize_vdevice.cc
* \brief Propagate virtual device information.
*/
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
namespace tvm {
namespace relax {
namespace {
class VDeviceLookup {
public:
explicit VDeviceLookup(IRModule mod) {
auto opt_global_info = mod->global_infos.Get("vdevice");
if (!opt_global_info) return;
auto downcast_vdevice = [](GlobalInfo info) -> VDevice {
if (auto vdevice = info.as<VDevice>()) {
return vdevice.value();
} else {
TVM_FFI_THROW(TypeError)
<< "Each item in an IRModule's \"vdevice\" annotation must be a VDevice, "
<< "but instead found item of type " << info->GetTypeKey();
}
};
opt_vdevices_ = opt_global_info.value().Map(downcast_vdevice);
}
VDevice operator()(Attrs hint_on_device_attrs) {
auto attrs = hint_on_device_attrs.as<HintOnDeviceAttrs>();
TVM_FFI_ICHECK(attrs);
int32_t device_type = attrs->device_type;
int32_t device_id = attrs->index;
ffi::String memory_scope = attrs->memory_scope;
TVM_FFI_CHECK(opt_vdevices_.defined(), ValueError)
<< "The target VDevice in the GlobalInfos was not found.";
auto vdevices = opt_vdevices_.value();
TVM_FFI_CHECK_GE(device_id, 0, ValueError)
<< "The device id in R.hint_on_device must not be negative";
for (auto vdevice : vdevices) {
int dev_type = vdevice->target->GetTargetDeviceType();
if (dev_type == device_type && vdevice->vdevice_id == device_id &&
memory_scope == vdevice->memory_scope) {
return vdevice;
}
}
TVM_FFI_THROW(ValueError)
<< "Expected to find device with type " << device_id << " and id " << device_id
<< ", but no such device was found in the IRModule's \"vdevice\" annotation";
TVM_FFI_UNREACHABLE();
}
private:
ffi::Optional<ffi::Array<VDevice>> opt_vdevices_ = std::nullopt;
};
class DeviceHintCollector : ExprVisitor {
public:
static std::tuple<ffi::Map<Var, VDevice>, ffi::Map<Var, VDevice>> Collect(IRModule mod) {
DeviceHintCollector visitor{VDeviceLookup(mod)};
for (const auto& [gvar, base_func] : mod->functions) {
if (auto func = base_func.as<Function>()) {
visitor(func.value());
}
}
return {visitor.known_vdevice_, visitor.hint_on_device_inputs_};
}
private:
explicit DeviceHintCollector(VDeviceLookup vdevice_lookup) : vdevice_lookup_(vdevice_lookup) {}
void VisitExpr_(const FunctionNode* func) override {
ExprVisitor::VisitExpr_(func);
std::function<void(Expr, StructInfo)> check_ret_sinfo = [this, &check_ret_sinfo](
Expr expr, StructInfo sinfo) {
// If the function is annotated as returning a tensor on a
// specific device, then that annotation may be propagated into
// the returned variable.
if (auto tensor_info = sinfo.as<TensorStructInfoNode>();
tensor_info && tensor_info->vdevice.defined()) {
if (auto opt_var = expr.as<Var>()) {
auto var = opt_var.value();
if (!known_vdevice_.count(var)) {
known_vdevice_.Set(var, tensor_info->vdevice.value());
}
}
}
// If the function is annotated as returning a tuple of tensors,
// where some elements of the tuple are tensors that exist on a
// specific device, then those annotations may be propagated
// into the corresponding tensor annotations.
if (auto tuple_info = sinfo.as<TupleStructInfoNode>()) {
// The returned tuple is not necessarily an in-line tuple. In
// order to find the variables that are bound to the
// individual tuple elements, we may need to unwrap the
// variable bindings in order to find the tuple itself. This
// unwrapping is not required for the tensor case, as it would
// already be handled when propagating VDevice across variable
// definitions.
while (auto bound_value = LookupBinding(expr)) {
expr = bound_value.value();
}
// Even after unwrapping variable bindings, the resulting
// expression is not required to be a tuple literal. For
// example, the function may return one of its arguments as an
// output, or may return the result of a `relax::Call` that
// produces a tuple of outputs.
if (auto tuple = expr.as<TupleNode>()) {
TVM_FFI_CHECK_EQ(tuple_info->fields.size(), tuple->fields.size(), ValueError)
<< "Function returns a tuple with " << tuple->fields.size() << " elements, "
<< "but is annotated as returning a tuple with " << tuple_info->fields.size()
<< " elements";
for (size_t i = 0; i < tuple->fields.size(); i++) {
check_ret_sinfo(tuple->fields[i], tuple_info->fields[i]);
}
}
}
};
check_ret_sinfo(func->body->body, func->ret_struct_info);
}
void VisitVarDef(const Var& var) override {
if (auto tinfo = var->struct_info_.as<TensorStructInfoNode>();
tinfo && tinfo->vdevice.defined()) {
known_vdevice_.Set(var, tinfo->vdevice.value());
}
ExprVisitor::VisitVarDef(var);
}
void VisitBinding(const Binding& binding) override {
ExprVisitor::VisitBinding(binding);
binding_lookup_.Set(binding->var, GetBoundValue(binding));
}
void VisitBinding_(const VarBindingNode* binding, const CallNode* call) override {
ExprVisitor::VisitBinding_(binding, call);
if (call->op == hint_on_device_op_) {
auto vdevice = vdevice_lookup_(call->attrs);
known_vdevice_.Set(binding->var, vdevice);
TVM_FFI_ICHECK_EQ(call->args.size(), 1);
if (auto arg_var = call->args[0].as<Var>()) {
hint_on_device_inputs_.Set(arg_var.value(), vdevice);
}
}
}
ffi::Optional<Expr> LookupBinding(const Expr& expr) const {
if (auto var = expr.as<Var>()) {
if (auto bound = binding_lookup_.Get(var.value())) {
return bound.value();
}
}
return std::nullopt;
}
// A lookup to identify the VDevice from the IRModule attributes,
// given the device type and device id from the R.hint_on_device
// attributes.
VDeviceLookup vdevice_lookup_;
// A lookup of variable bindings, used to unwrap the variable
// bindings in functions that return a tuple.
ffi::Map<Var, Expr> binding_lookup_;
// A map from Var to the VDevice they are known to occur on. This
// only contains variables whose location is explicitly known
// (e.g. output of `R.hint_on_device`, variables with explicit
// `VDevice` in their struct info), and does not include variables
// whose location is (e.g. input of `R.hint_on_device`).
ffi::Map<Var, VDevice> known_vdevice_;
// A map from Var to the VDevice they are expected to occur on. If
// a variable appears in both `known_vdevice_` and
// `hint_on_device_inputs_`, then `known_vdevice_` takes priority.
//
// For example, `B = R.hint_on_device(A, tvm.cuda(0))` implies that
// `B` must be located on "cuda:0". However, `A` may already have a
// `VDevice` annotation, or may be the output of `R.to_device`.
// Therefore, we only determine that `A` is located on "cuda:0" if
// no other annotation has already provided a known location for
// `A`.
ffi::Map<Var, VDevice> hint_on_device_inputs_;
// The `R.hint_on_device` operator.
const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device");
};
// Utility to determine which Var instances must be located on the
// same VDevice.
class VDeviceSetCollector : ExprVisitor {
public:
static ffi::Map<Var, ffi::Array<Var>> Collect(IRModule mod) {
VDeviceSetCollector visitor;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto func = base_func.as<Function>()) {
visitor(func.value());
}
}
return visitor.var_to_co_located_vars_;
}
private:
void VisitBinding(const Binding& binding) override {
auto cached = current_binding_;
current_binding_ = binding->var;
ExprVisitor::VisitBinding(binding);
current_binding_ = cached;
}
void VisitExpr_(const CallNode* call) override {
if (call->op != to_vdevice_op_ && call->op != hint_on_device_op_) {
ExprVisitor::VisitExpr_(call);
}
}
void VisitExpr_(const VarNode* op) override {
if (current_binding_) {
auto var = ffi::GetRef<Var>(op);
var_to_co_located_vars_[current_binding_.value()].push_back(var);
var_to_co_located_vars_[var].push_back(current_binding_.value());
}
}
ffi::Optional<Var> current_binding_ = std::nullopt;
// Lookup from relax variable to the set of relax variables which
// must be located on the same device. For example, a trivial
// binding `B = A` implies that both `B` and `A` are on the same
// device. Similarly, `C = R.add(A,B)` implies that `A`, `B`, and
// `C` are all on the same device.
//
// In general, variables that are used as part of the same
// `relax::Call` operation must be located on the same device, with
// the exception of `R.hint_on_device` and `R.to_vdevice`, which may
// introduce a transfer across devices.
std::unordered_map<Var, ffi::Array<Var>> var_to_co_located_vars_;
const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
};
ffi::Map<Var, VDevice> InferVDevice(IRModule mod) {
auto [explicit_annotations, hint_on_device_args] = DeviceHintCollector::Collect(mod);
auto co_located_var_lookup = VDeviceSetCollector::Collect(mod);
ffi::Map<Var, VDevice> known_vdevice;
std::vector<Var> to_visit;
// A helper function to propagate all `known_vdevice` entries based
// on the connections in `co_located_var_lookup`.
auto propagate = [&]() {
while (to_visit.size()) {
Var visiting = to_visit.back();
to_visit.pop_back();
if (auto upstream_vars = co_located_var_lookup.Get(visiting)) {
auto vdevice = known_vdevice.at(visiting);
for (Var upstream_var : upstream_vars.value()) {
if (!known_vdevice.count(upstream_var)) {
known_vdevice.Set(upstream_var, vdevice);
to_visit.push_back(upstream_var);
}
}
}
}
};
// First round, mark variables whose vdevice is explicitly known
// (e.g. the output of R.hint_on_device), and propagate.
for (const auto& [var, vdevice] : explicit_annotations) {
to_visit.push_back(var);
known_vdevice.Set(var, vdevice);
}
propagate();
// Second round, mark variables whose vdevice is hinted at (e.g. the
// input of R.hint_on_device), and propagate.
for (const auto& [var, vdevice] : hint_on_device_args) {
if (!known_vdevice.count(var)) {
to_visit.push_back(var);
known_vdevice.Set(var, vdevice);
}
}
propagate();
return known_vdevice;
}
// Update the module to include the inferred VDevice annotations.
class VDeviceStructInfoUpdater : ExprMutator {
public:
static IRModule Apply(IRModule mod, ffi::Map<Var, VDevice> vdevice_map) {
VDeviceStructInfoUpdater mutator(VDeviceLookup(mod), vdevice_map);
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto func = base_func.as<Function>()) {
auto updated = Downcast<Function>(mutator(func.value()));
if (!updated.same_as(base_func)) {
updates->Add(gvar, updated);
}
}
}
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return mod;
}
private:
VDeviceStructInfoUpdater(VDeviceLookup vdevice_lookup, ffi::Map<Var, VDevice> vdevice_map)
: vdevice_lookup_(vdevice_lookup), vdevice_map_(vdevice_map) {}
Var VisitVarDef(const Var& old_var) override {
auto var = ExprMutator::VisitVarDef(old_var);
if (auto tinfo = var->struct_info_.as<TensorStructInfoNode>()) {
if (auto opt = vdevice_map_.Get(old_var)) {
auto vdevice = opt.value();
TensorStructInfo new_sinfo = [&]() {
if (tinfo->shape.defined()) {
return TensorStructInfo(tinfo->shape.value(), tinfo->dtype, vdevice, tinfo->span);
} else {
return TensorStructInfo(tinfo->dtype, tinfo->ndim, vdevice, tinfo->span);
}
}();
if (var->IsInstance<DataflowVarNode>()) {
var = DataflowVar(var->vid, new_sinfo, var->span);
} else {
var = Var(var->vid, new_sinfo, var->span);
}
}
}
return var;
}
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* op) override {
auto call = Downcast<Call>(ExprMutator::VisitExpr_(op));
if (call->op != hint_on_device_op_) {
return call;
}
TVM_FFI_ICHECK_EQ(call->args.size(), 1);
auto arg = call->args[0];
auto input_vdevice = Downcast<TensorStructInfo>(arg->struct_info_)->vdevice;
auto output_vdevice = vdevice_lookup_(call->attrs);
if (input_vdevice.defined() && input_vdevice.value() == output_vdevice) {
return arg;
} else {
ffi::ObjectPtr<ToVDeviceAttrs> attrs = ffi::make_object<ToVDeviceAttrs>();
attrs->dst_vdevice = output_vdevice;
return Call(to_vdevice_op_, {arg}, Attrs(attrs), {});
}
}
VDeviceLookup vdevice_lookup_;
ffi::Map<Var, VDevice> vdevice_map_;
const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
};
} // namespace
namespace transform {
Pass RealizeVDevice() {
auto pass_func = [=](IRModule mod, PassContext pc) {
auto known_vdevices = InferVDevice(mod);
return VDeviceStructInfoUpdater::Apply(mod, known_vdevices);
};
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"RealizeVDevice",
/*required=*/{});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.transform.RealizeVDevice", RealizeVDevice);
}
} // namespace transform
} // namespace relax
} // namespace tvm