blob: 6d96327e2db48c79b21c61f1ad39a5decec488d7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/distributed/struct_info.h>
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
class AttrPrinter {
public:
explicit AttrPrinter(AccessPath p, const IRDocsifier& d, ffi::Array<ffi::String>* keys,
ffi::Array<ExprDoc>* values)
: p(std::move(p)), d(d), keys(keys), values(values) {}
void operator()(const tvm::Attrs& attrs) {
if (const auto* dict_attrs = attrs.as<DictAttrsNode>()) {
for (const auto& [key, value] : dict_attrs->dict) {
keys->push_back(key);
values->push_back(d->AsDoc<ExprDoc>(value, p->Attr(key)));
}
} else {
const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index());
ICHECK(attrs_tinfo->metadata != nullptr)
<< "Object `" << attrs->GetTypeKey()
<< "` misses reflection registration and do not support serialization";
// new printing mechanism using the new reflection
ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) {
ffi::String field_name = ffi::String(field_info->name);
Any field_value = ffi::reflection::FieldGetter(field_info)(attrs);
keys->push_back(field_name);
values->push_back(d->AsDoc<ExprDoc>(field_value, p->Attr(field_name)));
});
}
}
AccessPath p;
const IRDocsifier& d;
ffi::Array<ffi::String>* keys;
ffi::Array<ExprDoc>* values;
};
/*!
 * \brief Print the callee expression of a call-like op.
 *
 * An ExternFunc is printed as a string literal of its global symbol. Any other expression is
 * printed through the generic docsifier.
 */
ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifier& d) {
  // TODO(@junrushao): handle callee better
  const auto* extern_func = n.as<relax::ExternFuncNode>();
  if (extern_func == nullptr) {
    return d->AsDoc<ExprDoc>(n, n_p);
  }
  return LiteralDoc::Str(extern_func->global_symbol, n_p);
}
/*!
 * \brief Print calls to the destination-passing-style call operators.
 *
 * Handles relax.call_tir, relax.call_tir_inplace, relax.call_dps_packed,
 * relax.call_tir_with_grad and relax.dist.call_tir_local_view. The output has the shape
 * `R.<op>(callee, args, out_sinfo=..., [inplace_indices=...], [te_grad_name=...,
 * te_grad_kwargs=...], [tir_vars=...])`. A call_tir whose output struct info contains a
 * DTensorStructInfo is printed as `R.dist.call_tir`.
 *
 * \param n The call to print.
 * \param n_p The access path of `n`.
 * \param d The IRDocsifier.
 * \return The printed document, or std::nullopt if `n` is not one of the ops above.
 */
ffi::Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p,
                                             const IRDocsifier& d) {
  static const Op& call_tir_op = Op::Get("relax.call_tir");
  static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace");
  static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
  static const Op& call_tir_with_grad_op = Op::Get("relax.call_tir_with_grad");
  static const Op& call_tir_local_view = Op::Get("relax.dist.call_tir_local_view");
  if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op) &&
      !n->op.same_as(call_tir_with_grad_op) && !n->op.same_as(call_tir_local_view) &&
      !n->op.same_as(call_tir_inplace_op)) {
    return std::nullopt;
  }
  ICHECK(n->args.size() == 2 || n->args.size() == 3);
  ICHECK(n->sinfo_args.size() == 1);
  ffi::Array<ExprDoc> args;
  ffi::Array<ffi::String> kwargs_keys;
  ffi::Array<ExprDoc> kwargs_values;
  // Step 1. Print n->args[0], the callee
  args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d));
  // Step 2. Print n->args[1], the input arguments
  args.push_back(d->AsDoc<ExprDoc>(n->args[1], n_p->Attr("args")->ArrayItem(1)));
  // Step 3. Print n->sinfo_args, the output struct info
  relax::StructInfo o_sinfo = n->sinfo_args[0];
  AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0);
  bool is_dtensor = false;
  kwargs_keys.push_back("out_sinfo");
  if (const auto* o = o_sinfo.as<relax::TupleStructInfoNode>()) {
    ffi::Array<ExprDoc> fields;
    AccessPath fields_p = o_sinfo_p->Attr("fields");
    for (int i = 0, l = o->fields.size(); i < l; ++i) {
      if (o->fields[i].as<relax::distributed::DTensorStructInfoNode>()) {
        is_dtensor = true;
      }
      fields.push_back(d->AsDoc<ExprDoc>(o->fields[i], fields_p->ArrayItem(i)));
    }
    kwargs_values.push_back(ListDoc(fields));
  } else {
    if (o_sinfo.as<relax::distributed::DTensorStructInfoNode>()) {
      is_dtensor = true;
    }
    kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
  }
  // for call_tir_inplace, we also need to include the inplace args
  if (n->op.same_as(call_tir_inplace_op)) {
    kwargs_keys.push_back("inplace_indices");
    ffi::Array<ExprDoc> index_fields;
    if (const auto* call_tir_inplace_attrs = n->attrs.as<relax::CallTIRInplaceAttrs>()) {
      for (auto inplace_index : call_tir_inplace_attrs->inplace_indices) {
        index_fields.push_back(
            LiteralDoc::Int(inplace_index.IntValue(), n_p->Attr("attrs")->Attr("inplace_indices")));
      }
    }
    kwargs_values.push_back(ListDoc(index_fields));
  }
  // Attributes specific to call_tir_with_grad
  if (const auto* call_tir_with_grad_attrs = n->attrs.as<relax::CallTIRWithGradAttrs>()) {
    kwargs_keys.push_back("te_grad_name");
    kwargs_values.push_back(LiteralDoc::Str(call_tir_with_grad_attrs->te_grad_name,
                                            n_p->Attr("attrs")->Attr("te_grad_name")));
    if (!call_tir_with_grad_attrs->te_grad_kwargs.empty()) {
      kwargs_keys.push_back("te_grad_kwargs");
      kwargs_values.push_back(d->AsDoc<ExprDoc>(call_tir_with_grad_attrs->te_grad_kwargs,
                                                n_p->Attr("attrs")->Attr("te_grad_kwargs")));
    }
  }
  // call_dps_packed is emitted without a `tir_vars` keyword.
  if (n->op.same_as(call_dps_packed_op)) {
    return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values);
  }
  // Step 4. Print n->args[2], the tir variables.
  // This step runs before the call_tir_with_grad return below. call_tir_with_grad may also
  // carry a third `tir_vars` argument, and returning earlier would silently drop it from the
  // printed script.
  if (n->args.size() == 3) {
    kwargs_keys.push_back("tir_vars");
    kwargs_values.push_back(d->AsDoc<ExprDoc>(n->args[2], n_p->Attr("args")->ArrayItem(2)));
  }
  if (n->op.same_as(call_tir_with_grad_op)) {
    return Relax(d, "call_tir_with_grad")->Call(args, kwargs_keys, kwargs_values);
  }
  if (n->op.same_as(call_tir_local_view)) {
    return Relax(d, "dist.call_tir_local_view")->Call(args, kwargs_keys, kwargs_values);
  } else if (is_dtensor) {
    return Relax(d, "dist.call_tir")->Call(args, kwargs_keys, kwargs_values);
  } else if (n->op.same_as(call_tir_inplace_op)) {
    return Relax(d, "call_tir_inplace")->Call(args, kwargs_keys, kwargs_values);
  } else {
    return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values);
  }
}
/*!
 * \brief Print a relax.assert_op call as `R.assert_op(cond, *args, format=fmt)`.
 *
 * \return The printed document, or std::nullopt if `n` is not relax.assert_op.
 */
ffi::Optional<ExprDoc> PrintAssertOp(const relax::Call& n, const AccessPath& n_p,
                                     const IRDocsifier& d) {
  static const Op& assert_op = Op::Get("relax.assert_op");
  if (!n->op.same_as(assert_op)) {
    return std::nullopt;
  }
  ICHECK(n->args.size() >= 2);
  AccessPath args_p = n_p->Attr("args");
  // The second argument is the format string. It is emitted as the `format=` keyword so that
  // reparsing does not treat it as an ordinary argument and supply a new default format.
  // Arguments are docsified strictly in index order: args[0], args[1], then the rest.
  ffi::Array<ExprDoc> positional;
  positional.push_back(d->AsDoc<ExprDoc>(n->args[0], args_p->ArrayItem(0)));
  ExprDoc format_doc = d->AsDoc<ExprDoc>(n->args[1], args_p->ArrayItem(1));
  for (size_t idx = 2; idx < n->args.size(); ++idx) {
    positional.push_back(d->AsDoc<ExprDoc>(n->args[idx], args_p->ArrayItem(idx)));
  }
  return Relax(d, "assert_op")->Call(positional, {"format"}, {format_doc});
}
/*!
 * \brief Print a relax.hint_on_device call as `R.hint_on_device(arg, R.device(...), scope)`.
 *
 * When the attrs are not HintOnDeviceAttrs, only the argument is printed.
 *
 * \return The printed document, or std::nullopt if `n` is not relax.hint_on_device.
 */
ffi::Optional<ExprDoc> PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p,
                                         const IRDocsifier& d) {
  static const Op& hint_on_device_op = Op::Get("relax.hint_on_device");
  if (!n->op.same_as(hint_on_device_op)) {
    return std::nullopt;
  }
  ffi::Array<ExprDoc> args;
  args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d));
  ICHECK(n->attrs.defined());
  if (!n->attrs.as<relax::HintOnDeviceAttrs>()) {
    return Relax(d, "hint_on_device")->Call(args);
  }
  ffi::Array<ffi::String> device_keys;
  ffi::Array<ExprDoc> device_values;
  AttrPrinter(n_p->Attr("attrs"), d, &device_keys, &device_values)(n->attrs);
  // The last reflected attribute is split off and passed positionally as the scope; the
  // remaining ones become keyword arguments of R.device(...). This relies on the field
  // registration order of HintOnDeviceAttrs.
  ExprDoc scope_doc = device_values.back();
  device_keys.pop_back();
  device_values.pop_back();
  args.push_back(Relax(d, "device")->Call({}, device_keys, device_values));
  args.push_back(scope_doc);
  return Relax(d, "hint_on_device")->Call(args);
}
/*!
 * \brief Print a relax.to_vdevice call as `R.to_vdevice(arg, dst_vdevice="kind:index:scope")`.
 *
 * The device string joins the target kind name, the index returned by
 * FindVDeviceIndexByTargetKind, and the virtual device's memory scope with ':'.
 *
 * \return The printed document, or std::nullopt if `n` is not relax.to_vdevice.
 */
ffi::Optional<ExprDoc> PrintToVDevice(const relax::Call& n, const AccessPath& n_p,
                                      const IRDocsifier& d) {
  static const Op& to_vdevice_op = Op::Get("relax.to_vdevice");
  if (!n->op.same_as(to_vdevice_op)) {
    return std::nullopt;
  }
  ffi::Array<ExprDoc> args;
  args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d));
  ffi::Array<ffi::String> kwargs_keys;
  ffi::Array<ExprDoc> kwargs_values;
  ICHECK(n->attrs.defined());
  if (const auto* attrs = n->attrs.as<relax::ToVDeviceAttrs>()) {
    VDevice vdev = attrs->dst_vdevice;
    std::string dev_kind = vdev->target->kind->name;
    int dev_index = FindVDeviceIndexByTargetKind(vdev, d);
    kwargs_keys.push_back("dst_vdevice");
    // `dst_vdevice` is a field of the call's attrs, not of the call itself, so its source
    // path is `attrs.dst_vdevice`. This matches the other attribute-derived kwargs.
    kwargs_values.push_back(
        LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index) + ":" + vdev->memory_scope,
                        n_p->Attr("attrs")->Attr("dst_vdevice")));
  }
  return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values);
}
/*!
 * \brief Print a relax.print call as `R.print(*args, format=fmt)`.
 *
 * \return The printed document, or std::nullopt if `n` is not relax.print.
 */
ffi::Optional<ExprDoc> PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p,
                                       const IRDocsifier& d) {
  static const Op& print_op = Op::Get("relax.print");
  if (!n->op.same_as(print_op)) {
    return std::nullopt;
  }
  ICHECK(n->args.size() >= 1);
  AccessPath args_p = n_p->Attr("args");
  // The first argument is the format string. It is emitted as the `format=` keyword so that
  // reparsing does not treat it as an ordinary argument and supply a new default format.
  ExprDoc format_doc = d->AsDoc<ExprDoc>(n->args[0], args_p->ArrayItem(0));
  ffi::Array<ExprDoc> printed_values;
  for (size_t idx = 1; idx < n->args.size(); ++idx) {
    printed_values.push_back(d->AsDoc<ExprDoc>(n->args[idx], args_p->ArrayItem(idx)));
  }
  return Relax(d, "print")->Call(printed_values, {"format"}, {format_doc});
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::Call>(  //
        "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc {
          // Ops with dedicated printing rules are tried first, in this order. The first printer
          // that recognizes the call (returns a non-empty Optional) produces the document.
          using SpecialCallPrinter = ffi::Optional<ExprDoc> (*)(
              const relax::Call&, const AccessPath&, const IRDocsifier&);
          const SpecialCallPrinter special_printers[] = {
              PrintCallTIRDPSPacked,  // call_tir / call_dps_packed / call_tir_with_grad / ...
              PrintAssertOp,          // assert_op
              PrintHintOnDevice,      // hint_on_device
              PrintToVDevice,         // to_vdevice
              PrintRelaxPrint,        // print
          };
          for (SpecialCallPrinter printer : special_printers) {
            if (ffi::Optional<ExprDoc> special = printer(n, n_p, d)) {
              return special.value();
            }
          }

          // Generic form: `<prefix>(args..., attr kwargs..., sinfo_args=(...))`.
          AccessPath op_p = n_p->Attr("op");
          AccessPath args_p = n_p->Attr("args");
          AccessPath attrs_p = n_p->Attr("attrs");
          ExprDoc prefix{ffi::UnsafeInit()};
          ffi::Array<ExprDoc> args;
          ffi::Array<ffi::String> kwargs_keys;
          ffi::Array<ExprDoc> kwargs_values;

          // Step 1. Choose the callee prefix from the kind of op.
          if (const auto* extern_func = n->op.as<relax::ExternFuncNode>()) {
            // Extern functions become R.call_packed("symbol", ...).
            prefix = Relax(d, "call_packed");
            args.push_back(LiteralDoc::Str(extern_func->global_symbol, op_p));
          } else if (const auto* op_node = n->op.as<tvm::OpNode>()) {
            // Ops in the "relax." namespace print as R.<rest>; others print by full name.
            std::string op_name = op_node->name;
            if (op_name.compare(0, 6, "relax.") == 0) {
              prefix = Relax(d, op_name.substr(6));
            } else {
              prefix = IdDoc(op_name);
            }
            prefix->source_paths.push_back(op_p);
          } else if (n->op->IsInstance<relax::VarNode>() ||
                     n->op->IsInstance<tvm::GlobalVarNode>()) {
            prefix = d->AsDoc<ExprDoc>(n->op, op_p);
          } else {
            LOG(FATAL) << "TypeError: Unsupported op: " << n->op->GetTypeKey();
          }

          // Step 2. Positional arguments; the first one goes through PrintCallee.
          for (int idx = 0, count = n->args.size(); idx < count; ++idx) {
            AccessPath arg_p = args_p->ArrayItem(idx);
            if (idx == 0) {
              args.push_back(PrintCallee(n->args[0], arg_p, d));
            } else {
              args.push_back(d->AsDoc<ExprDoc>(n->args[idx], arg_p));
            }
          }

          // Step 3. Attributes become keyword arguments.
          if (n->attrs.defined()) {
            if (n->op->IsInstance<relax::ExternFuncNode>()) {
              kwargs_keys.push_back("attrs_type_key");
              kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(), attrs_p));
            }
            if (const auto* dict_attrs = n->attrs.as<tvm::DictAttrsNode>()) {
              // Dictionary entries are emitted in ascending key order for stable output.
              std::vector<std::pair<ffi::String, ffi::Any>> entries;
              entries.reserve(dict_attrs->dict.size());
              for (const auto& entry : dict_attrs->dict) {
                entries.push_back(entry);
              }
              std::sort(entries.begin(), entries.end(),
                        [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; });
              for (const auto& entry : entries) {
                kwargs_keys.push_back(entry.first);
                kwargs_values.push_back(d->AsDoc<ExprDoc>(entry.second, attrs_p->Attr(entry.first)));
              }
            } else {
              AttrPrinter(attrs_p, d, &kwargs_keys, &kwargs_values)(n->attrs);
            }
          }

          // Step 4. Explicit struct-info arguments become `sinfo_args=(...)`.
          if (!n->sinfo_args.empty()) {
            AccessPath sinfo_args_p = n_p->Attr("sinfo_args");
            ffi::Array<ExprDoc> sinfo_docs;
            for (int idx = 0, count = n->sinfo_args.size(); idx < count; ++idx) {
              sinfo_docs.push_back(
                  d->AsDoc<ExprDoc>(n->sinfo_args[idx], sinfo_args_p->ArrayItem(idx)));
            }
            kwargs_keys.push_back("sinfo_args");
            kwargs_values.push_back(TupleDoc(sinfo_docs));
          }
          return prefix->Call(args, kwargs_keys, kwargs_values);
        });
// Hook relax::CallNode's repr up to ReprPrintRelax, so printing a Call uses the dispatch above.
TVM_SCRIPT_REPR(relax::CallNode, ReprPrintRelax);
} // namespace printer
} // namespace script
} // namespace tvm