blob: 7dddfaecbbe7ffbe16879de7a2f50bac824c8dbb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/type.h>
#include <tvm/relax/utils.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../utils.h"
namespace tvm {
namespace script {
namespace printer {
class RelaxFrameNode : public FrameNode {
public:
bool is_func = false;
bool module_alias_printed = false;
std::unordered_set<const tir::VarNode*>* func_vars = nullptr;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<RelaxFrameNode>()
.def_ro("is_func", &RelaxFrameNode::is_func)
.def_ro("module_alias_printed", &RelaxFrameNode::module_alias_printed);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.RelaxFrame", RelaxFrameNode, FrameNode);
};
class RelaxFrame : public Frame {
public:
explicit RelaxFrame(const IRDocsifier& d) {
ObjectPtr<RelaxFrameNode> n = ffi::make_object<RelaxFrameNode>();
n->stmts.clear();
n->d = d.get();
n->is_func = false;
n->func_vars = nullptr;
data_ = std::move(n);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RelaxFrame, Frame, RelaxFrameNode);
};
/*! \brief Redirected method for the ReprPrinter */
inline std::string ReprPrintRelax(const ObjectRef& obj, const PrinterConfig& cfg) {
IRDocsifier d(cfg);
With<RelaxFrame> f(d);
(*f)->AddDispatchToken(d, "relax");
return Docsify(obj, d, *f, cfg);
}
inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsifier& d) {
return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint());
}
inline ffi::Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p,
const IRDocsifier& d,
const ffi::Optional<relax::Expr>& rhs) {
if (!v->struct_info_.defined()) {
return std::nullopt;
}
bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info;
if (const auto* call = rhs.as<relax::CallNode>()) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
attempt_to_hide_struct_info = true;
}
}
if (attempt_to_hide_struct_info) {
ffi::Optional<relax::StructInfo> inferred_sinfo = std::nullopt;
if (auto opt = rhs.as<relax::Call>()) {
auto call = opt.value();
if (auto opt = call->op.as<Op>()) {
auto op = opt.value();
static auto op_map_infer_struct_info =
Op::GetAttrMap<relax::FInferStructInfo>("FInferStructInfo");
auto temp_builder = relax::BlockBuilder::Create(std::nullopt);
inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder);
} else if (auto opt = call->op.as<relax::FuncStructInfo>()) {
auto temp_builder = relax::BlockBuilder::Create(std::nullopt);
inferred_sinfo =
DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer());
}
} else if (const auto* tuple = rhs.as<relax::TupleNode>()) {
inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo));
} else if (const auto* get_item = rhs.as<relax::TupleGetItemNode>()) {
if (auto ptr = get_item->tuple->struct_info_.as<relax::TupleStructInfoNode>();
ptr && get_item->index < static_cast<int>(ptr->fields.size())) {
inferred_sinfo = ptr->fields[get_item->index];
}
} else if (const auto* trivial_binding = rhs.as<relax::VarNode>()) {
inferred_sinfo = trivial_binding->struct_info_.as<relax::StructInfo>();
}
if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) {
return std::nullopt;
}
}
return d->AsDoc<ExprDoc>(v->struct_info_, v_p->Attr("struct_info_"));
}
ffi::Array<StmtDoc> PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p,
const IRDocsifier& d, bool use_ret);
ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d);
inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) {
ffi::Array<GlobalInfo> vdevices = d->global_infos["vdevice"];
int kind_index = 0;
for (size_t i = 0; i < vdevices.size(); ++i) {
auto vdev = Downcast<VDevice>(vdevices[i]);
if (vdev.same_as(vdevice)) {
return kind_index;
}
if (vdev->target->kind->name == vdevice->target->kind->name) {
kind_index++;
}
}
return -1;
}
} // namespace printer
} // namespace script
} // namespace tvm
#endif // TVM_SCRIPT_PRINTER_RELAX_UTILS_H_