blob: 36c61d78b345cd6c3bb9c21c73f256545db2481b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/script_printer.h>
#include <algorithm>
namespace tvm {
using AccessPath = ffi::reflection::AccessPath;

// Invoke PrinterConfigNode::RegisterReflection() from a TVM FFI static-init block so the
// reflection metadata of the printer config is registered with the FFI layer.
TVM_FFI_STATIC_INIT_BLOCK() {
  PrinterConfigNode::RegisterReflection();
}
TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
  // Single dispatch table shared by every caller: a function-local static, constructed on the
  // first call and returned by reference afterwards so registrations persist.
  static FType dispatch_table;
  return dispatch_table;
}
std::string TVMScriptPrinter::Script(const ObjectRef& node,
                                     const ffi::Optional<PrinterConfig>& cfg) {
  // Render `node` with the printer registered for its type in the dispatch table; an absent
  // `cfg` is replaced by a default-constructed PrinterConfig.
  FType& printers = TVMScriptPrinter::vtable();
  if (printers.can_dispatch(node)) {
    return printers(node, cfg.value_or(PrinterConfig()));
  }
  // No script printer is registered for this node type: fall back to the generic
  // ReprPrinter representation.
  std::ostringstream fallback;
  ReprPrinter repr(fallback);
  repr.Print(node);
  return fallback.str();
}
// Check whether `name` is a valid (ASCII) Python identifier, i.e. whether it matches the regex
// "^[a-zA-Z_][a-zA-Z0-9_]*$".
//
// `std::regex` would cause a symbol conflict with PyTorch, so the codebase avoids using it.
// Instead, the regex is expanded into the following conditions:
//   1. The name is not empty.
//   2. The first character is an ASCII letter or an underscore.
//   3. Every remaining character is an ASCII letter, an ASCII digit, or an underscore.
//
// The character classes are tested with explicit ranges rather than `std::isalpha` /
// `std::isalnum`. Those functions have undefined behavior for negative `char` values (e.g. the
// bytes of a UTF-8 sequence on platforms where `char` is signed), and their result depends on
// the current C locale. The regex above is strictly ASCII.
//
// \param name The candidate identifier.
// \return true if `name` matches the regex above, false otherwise.
bool IsIdentifier(const std::string& name) {
  auto is_ascii_alpha = [](char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); };
  auto is_ascii_digit = [](char c) { return c >= '0' && c <= '9'; };
  // Characters allowed in the first position.
  auto is_head_char = [&](char c) { return is_ascii_alpha(c) || c == '_'; };
  // Characters allowed in every subsequent position.
  auto is_tail_char = [&](char c) { return is_head_char(c) || is_ascii_digit(c); };
  return !name.empty() && is_head_char(name.front()) &&
         std::all_of(name.begin() + 1, name.end(), is_tail_char);
}
PrinterConfig::PrinterConfig(ffi::Map<ffi::String, Any> config_dict) {
  // Start from a freshly constructed PrinterConfigNode. Each key found in `config_dict`
  // overrides one field; keys that are absent leave that field as make_object produced it.
  runtime::ObjectPtr<PrinterConfigNode> node = ffi::make_object<PrinterConfigNode>();
  // Dtype-valued entries arrive as strings and are parsed with ffi::StringToDLDataType.
  auto parse_dtype = [](const ffi::String& text) {
    return DataType(ffi::StringToDLDataType(text));
  };

  if (auto entry = config_dict.Get("name")) {
    // "name" is appended to binding_names rather than replacing its contents.
    node->binding_names.push_back(Downcast<ffi::String>(entry.value()));
  }
  if (auto entry = config_dict.Get("show_meta")) {
    node->show_meta = entry.value().cast<bool>();
  }

  // Prefixes and module alias; they are validated once all keys have been read.
  if (auto entry = config_dict.Get("ir_prefix")) {
    node->ir_prefix = Downcast<ffi::String>(entry.value());
  }
  if (auto entry = config_dict.Get("tir_prefix")) {
    node->tir_prefix = Downcast<ffi::String>(entry.value());
  }
  if (auto entry = config_dict.Get("relax_prefix")) {
    node->relax_prefix = Downcast<ffi::String>(entry.value());
  }
  if (auto entry = config_dict.Get("module_alias")) {
    node->module_alias = Downcast<ffi::String>(entry.value());
  }

  // Dtype defaults.
  if (auto entry = config_dict.Get("buffer_dtype")) {
    node->buffer_dtype = parse_dtype(Downcast<ffi::String>(entry.value()));
  }
  if (auto entry = config_dict.Get("int_dtype")) {
    node->int_dtype = parse_dtype(Downcast<ffi::String>(entry.value()));
  }
  if (auto entry = config_dict.Get("float_dtype")) {
    node->float_dtype = parse_dtype(Downcast<ffi::String>(entry.value()));
  }

  // Output formatting switches.
  if (auto entry = config_dict.Get("verbose_expr")) {
    node->verbose_expr = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("indent_spaces")) {
    node->indent_spaces = entry.value().cast<int>();
  }
  if (auto entry = config_dict.Get("print_line_numbers")) {
    node->print_line_numbers = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("num_context_lines")) {
    node->num_context_lines = entry.value().cast<int>();
  }

  // Underline / annotation targets. The whole entry is downcast to an Optional container, and
  // an empty Optional (presumably a None value from the caller) becomes an empty container.
  if (auto entry = config_dict.Get("path_to_underline")) {
    node->path_to_underline =
        Downcast<ffi::Optional<ffi::Array<AccessPath>>>(entry).value_or(ffi::Array<AccessPath>());
  }
  if (auto entry = config_dict.Get("path_to_annotate")) {
    node->path_to_annotate =
        Downcast<ffi::Optional<ffi::Map<AccessPath, ffi::String>>>(entry).value_or(
            ffi::Map<AccessPath, ffi::String>());
  }
  if (auto entry = config_dict.Get("obj_to_underline")) {
    node->obj_to_underline =
        Downcast<ffi::Optional<ffi::Array<ObjectRef>>>(entry).value_or(ffi::Array<ObjectRef>());
  }
  if (auto entry = config_dict.Get("obj_to_annotate")) {
    node->obj_to_annotate =
        Downcast<ffi::Optional<ffi::Map<ObjectRef, ffi::String>>>(entry).value_or(
            ffi::Map<ObjectRef, ffi::String>());
  }

  // Remaining boolean switches.
  if (auto entry = config_dict.Get("syntax_sugar")) {
    node->syntax_sugar = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("show_object_address")) {
    node->show_object_address = entry.value().cast<bool>();
  }
  if (auto entry = config_dict.Get("show_all_struct_info")) {
    node->show_all_struct_info = entry.value().cast<bool>();
  }

  // Every prefix must be a valid Python identifier; module_alias may also be left empty.
  CHECK(IsIdentifier(node->ir_prefix)) << "Invalid `ir_prefix`: " << node->ir_prefix;
  CHECK(IsIdentifier(node->tir_prefix)) << "Invalid `tir_prefix`: " << node->tir_prefix;
  CHECK(IsIdentifier(node->relax_prefix)) << "Invalid `relax_prefix`: " << node->relax_prefix;
  CHECK(node->module_alias.empty() || IsIdentifier(node->module_alias))
      << "Invalid `module_alias`: " << node->module_alias;
  this->data_ = std::move(node);
}
ffi::Array<ffi::String> PrinterConfigNode::GetBuiltinKeywords() {
  // Names reserved by the printer: the IR, TIR and Relax prefixes, in that order, followed by
  // the module alias when one is configured.
  ffi::Array<ffi::String> keywords;
  keywords.push_back(ir_prefix);
  keywords.push_back(tir_prefix);
  keywords.push_back(relax_prefix);
  if (module_alias.empty()) {
    return keywords;
  }
  keywords.push_back(module_alias);
  return keywords;
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("node.PrinterConfig",
[](ffi::Map<ffi::String, Any> config_dict) { return PrinterConfig(config_dict); })
.def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script);
}
} // namespace tvm