// blob: e597df64501d869d315149f800bd8c507daea605 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/tir/stmt_functor.h>
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
// Printer for `relax::ObjectStructInfo`: the node and its access path are never inspected; the
// result is always the `Object` attribute of the Relax prefix.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::ObjectStructInfo>(  //
        "", [](relax::ObjectStructInfo /*sinfo*/, AccessPath /*sinfo_p*/, IRDocsifier d) -> Doc {
          return Relax(d, "Object");
        });
// Prints a PrimExpr used inside a struct-info annotation (in this file: a shape dimension or a
// PrimStructInfo value).
//
// The first RelaxFrame in `d->frames` that has `func_vars` set is treated as the collector of
// function-level symbolic vars. If such a frame exists and `e` references at least one of those
// vars, the expression's script text is returned as a string literal doc. Otherwise the ordinary
// ExprDoc for `e` is returned.
//
// Parameters:
//   e   - the expression to print.
//   e_p - access path of `e`, attached to whichever doc is returned.
//   d   - the docsifier that supplies frames and printer config.
ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d) {
  ExprDoc printed = d->AsDoc<ExprDoc>(e, e_p);

  // Look for a relax frame that is currently collecting function-level vars.
  const RelaxFrameNode* collecting = nullptr;
  for (const Frame& frame : d->frames) {
    const auto* candidate = frame.as<RelaxFrameNode>();
    if (candidate != nullptr && candidate->func_vars) {
      collecting = candidate;
      break;
    }
  }
  if (collecting == nullptr) {
    return printed;
  }

  // Does any sub-node of `e` name one of the collected vars?
  bool uses_func_var = false;
  tir::PostOrderVisit(e, [collecting, &uses_func_var](const ObjectRef& node) -> void {
    const auto* var = node.as<tir::VarNode>();
    if (var != nullptr && collecting->func_vars->count(var)) {
      uses_func_var = true;
    }
  });
  if (!uses_func_var) {
    return printed;
  }
  // Emit the expression's script text as a string literal instead of a live expression.
  return LiteralDoc::Str(DocToPythonScript(printed, d->cfg), e_p);
}
// Printer for `relax::PrimStructInfo`.
// With a known value, prints `Prim(value=<expr>)`, where the value goes through PrintShapeVar.
// Without one, prints `Prim(<dtype>)`, with the dtype as the single positional argument.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::PrimStructInfo>(
        "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc {
          if (!n->value.defined()) {
            ExprDoc dtype_doc = LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"));
            return Relax(d, "Prim")->Call({dtype_doc});
          }
          ExprDoc value_doc = PrintShapeVar(n->value.value(), n_p->Attr("value"), d);
          return Relax(d, "Prim")->Call({}, {"value"}, {value_doc});
        });
// Printer for `relax::ShapeStructInfo`.
// Without known values, only the `ndim` keyword is printed. Otherwise each dimension goes
// through PrintShapeVar, and the results are passed as one list argument.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::ShapeStructInfo>(
        "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc {
          if (!n->values.defined()) {
            return Relax(d, "Shape")
                ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))});
          }
          ffi::Array<PrimExpr> dims = n->values.value();
          AccessPath dims_p = n_p->Attr("values");
          ffi::Array<ExprDoc> dim_docs;
          int index = 0;
          for (const PrimExpr& dim : dims) {
            dim_docs.push_back(PrintShapeVar(dim, dims_p->ArrayItem(index), d));
            ++index;
          }
          return Relax(d, "Shape")->Call({ListDoc(dim_docs)});
        });
// Printer for `relax::TensorStructInfo`.
// Positional argument (only when a shape is present):
//   - for a literal ShapeExpr, a tuple of per-dimension docs (each via PrintShapeVar);
//   - otherwise, the doc for the shape expression itself.
// Keyword arguments, in this order:
//   - `dtype`, when the dtype is known;
//   - `ndim`, only when no shape is present and ndim is known;
//   - `vdevice`, when a vdevice with a target is attached, formatted as
//     "<target kind>:<index>:<memory scope>".
// If there are no arguments at all, the bare `Tensor` attribute is printed.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::TensorStructInfo>(  //
        "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc {
          ffi::Array<ExprDoc> positional;
          ffi::Array<ffi::String> kw_keys;
          ffi::Array<ExprDoc> kw_values;
          // Keeps keyword names and values aligned by index.
          auto add_kwarg = [&kw_keys, &kw_values](const char* key, const ExprDoc& value) {
            kw_keys.push_back(key);
            kw_values.push_back(value);
          };

          if (n->shape.defined()) {
            const auto& shape = n->shape.value();
            if (const auto* shape_node = shape.as<relax::ShapeExprNode>()) {
              // Literal shapes are expanded dimension by dimension into a tuple instead of
              // going through the generic ShapeExpr printer.
              AccessPath values_p = n_p->Attr("shape")->Attr("values");
              ffi::Array<ExprDoc> dim_docs;
              int index = 0;
              for (const PrimExpr& dim : shape_node->values) {
                dim_docs.push_back(PrintShapeVar(dim, values_p->ArrayItem(index), d));
                ++index;
              }
              positional.push_back(TupleDoc(dim_docs));
            } else {
              positional.push_back(d->AsDoc<ExprDoc>(shape, n_p->Attr("shape")));
            }
          }

          if (!n->IsUnknownDtype()) {
            add_kwarg("dtype", LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
          }
          // When a shape is present it already conveys the rank, so ndim is skipped.
          if (!n->shape.defined() && !n->IsUnknownNdim()) {
            add_kwarg("ndim", LiteralDoc::Int(n->ndim, n_p->Attr("ndim")));
          }
          if (n->vdevice.defined()) {
            const auto& vdev = n->vdevice.value();
            if (vdev->target.defined()) {
              std::string dev_kind = vdev->target->kind->name;
              int dev_index = FindVDeviceIndexByTargetKind(vdev, d);
              auto label = dev_kind + ":" + std::to_string(dev_index) + ":" + vdev->memory_scope;
              add_kwarg("vdevice", LiteralDoc::Str(label, n_p->Attr("vdevice")));
            }
          }

          if (positional.empty() && kw_keys.empty()) {
            return Relax(d, "Tensor");
          }
          return Relax(d, "Tensor")->Call(positional, kw_keys, kw_values);
        });
// Printer for `relax::TupleStructInfo`.
// An empty tuple prints as the bare `Tuple` attribute. Otherwise each field's struct info is
// printed and passed as a positional argument, in field order.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::TupleStructInfo>(  //
        "", [](relax::TupleStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc {
          if (n->fields.empty()) {
            return Relax(d, "Tuple");
          }
          AccessPath fields_p = n_p->Attr("fields");
          ffi::Array<ExprDoc> field_docs;
          int index = 0;
          for (const relax::StructInfo& field : n->fields) {
            field_docs.push_back(d->AsDoc<ExprDoc>(field, fields_p->ArrayItem(index)));
            ++index;
          }
          return Relax(d, "Tuple")->Call(field_docs);
        });
// Printer for `relax::FuncStructInfo`.
// Opaque functions print as `Callable`. Keyword arguments are added as follows:
//   - `ret`, only when the return struct info is not ObjectStructInfo;
//   - `purity`, only when purity is true.
// Other functions print as `Callable((<params...>), <ret>, <purity>)`, with all three positional.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::FuncStructInfo>(  //
        "", [](relax::FuncStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc {
          // Both docs are built before branching, even though the opaque path may use neither.
          auto ret_doc = d->AsDoc<ExprDoc>(n->ret, n_p->Attr("ret"));
          auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity"));

          if (n->IsOpaque()) {
            ffi::Array<ffi::String> kw_keys;
            ffi::Array<ExprDoc, void> kw_values;
            const bool ret_is_object = n->ret->IsInstance<relax::ObjectStructInfoNode>();
            if (!ret_is_object) {
              kw_keys.push_back("ret");
              kw_values.push_back(ret_doc);
            }
            if (n->purity) {
              kw_keys.push_back("purity");
              kw_values.push_back(purity_doc);
            }
            ExprDoc callable = Relax(d, "Callable");
            return kw_keys.empty() ? callable : callable->Call({}, kw_keys, kw_values);
          }

          // TODO(@junrushao): track symbolic shape relation
          // Non-opaque path: `params` is read via value(); presumably IsOpaque() is false exactly
          // when params are defined (TODO confirm).
          ffi::Array<relax::StructInfo> params = n->params.value();
          AccessPath params_p = n_p->Attr("params");
          ffi::Array<ExprDoc> param_docs;
          int index = 0;
          for (const relax::StructInfo& param : params) {
            param_docs.push_back(d->AsDoc<ExprDoc>(param, params_p->ArrayItem(index)));
            ++index;
          }
          return Relax(d, "Callable")->Call({TupleDoc(param_docs), ret_doc, purity_doc});
        });
// Register ReprPrintRelax for every struct-info node type that has a printer dispatch above.
// Going by the macro name, this presumably makes their text repr go through the Relax script
// printer (TODO confirm against the TVM_SCRIPT_REPR definition).
TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::PrimStructInfoNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::ShapeStructInfoNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::TensorStructInfoNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::TupleStructInfoNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::FuncStructInfoNode, ReprPrintRelax);
} // namespace printer
} // namespace script
} // namespace tvm