blob: e609f1b8efd2763eec54380ca429d11f38eab58a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/relax/analysis.h>
#include <tvm/script/ir_builder/ir/ir.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include "./utils.h"
namespace tvm {
namespace script {
namespace ir_builder {
namespace ir {
/*!
 * \brief Create a new IRModule frame whose global-var table and function table are empty.
 * \return The newly constructed IRModuleFrame.
 */
IRModuleFrame IRModule() {
  auto node = ffi::make_object<IRModuleFrameNode>();
  // Reset both tables explicitly so the frame starts from a clean state.
  node->functions.clear();
  node->global_var_map.clear();
  return IRModuleFrame(node);
}
/*!
 * \brief Compute the struct info to attach to the GlobalVar that names \p func.
 *
 * If \p func already has struct info, that struct info is returned. If it has
 * none but is a tir::PrimFunc, an opaque function struct info built from the
 * PrimFunc's return type is returned. Any other function kind is a fatal error.
 *
 * \param func The function whose GlobalVar struct info is needed.
 * \return The struct info for the GlobalVar.
 */
inline relax::StructInfo GetGlobalVarStructInfo(const BaseFunc& func) {
  if (func->struct_info_.defined()) {
    return tvm::relax::GetStructInfo(func);
  }
  if (const auto* prim = func.as<tvm::tir::PrimFuncNode>()) {
    relax::StructInfo ret_sinfo = tvm::relax::StructInfoFromType(prim->ret_type);
    return tvm::relax::FuncStructInfo::OpaqueFunc(ret_sinfo);
  }
  LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey();
}
/*!
 * \brief Declare a function in the enclosing IRModule frame.
 *
 * Creates a new GlobalVar named \p func_name and attaches struct info derived
 * from \p func_signature (see GetGlobalVarStructInfo). It then records the
 * name -> GlobalVar mapping and the GlobalVar -> function mapping in the frame.
 * A later DefFunction call with the same name overwrites the function entry.
 *
 * \param func_name The function name; it must not already be declared in the frame.
 * \param func_signature The function stored as the placeholder for this name.
 * \return The newly created GlobalVar.
 */
GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signature) {
  IRModuleFrame frame = FindModuleFrame();
  CHECK(!frame->global_var_map.count(func_name))
      << "ValueError: function " << func_name << " already exists";
  // Note: a previously computed-but-unused FuncType (built via GetType on every
  // PrimFunc parameter) was removed here; only the struct info below is attached.
  GlobalVar gv = GlobalVar(func_name);
  gv->struct_info_ = GetGlobalVarStructInfo(func_signature);
  // Defensive consistency check on the function table before inserting.
  CHECK(frame->functions.find(gv) == frame->functions.end())
      << "ValueError: function " << func_name << " has already been defined.";
  frame->global_var_map.Set(func_name, gv);
  frame->functions.Set(gv, func_signature);
  return gv;
}
/*!
 * \brief Attach the definition of a previously declared function in the enclosing IRModule frame.
 *
 * \param func_name Name of a function already registered via DeclFunction.
 * \param func The function definition that replaces the declared placeholder.
 */
void DefFunction(const ffi::String& func_name, const BaseFunc& func) {
  IRModuleFrame frame = FindModuleFrame();
  auto entry = frame->global_var_map.find(func_name);
  CHECK(entry != frame->global_var_map.end())
      << "ValueError: function " << func_name << " does not exist, please declare it first.";
  // Hold our own handle to the GlobalVar rather than a reference into the iterator's pair.
  GlobalVar gv = (*entry).second;
  frame->functions.Set(gv, func);
  // Refresh the GlobalVar's struct info so it matches the definition, not the declaration.
  gv->struct_info_ = GetGlobalVarStructInfo(func);
}
/*!
 * \brief Replace the attribute map of the enclosing IRModule frame.
 *
 * Does nothing when no IRBuilder is in scope.
 *
 * \param attrs The attribute map to install.
 * \param allow_overwrite If false, it is a fatal error for the frame to already
 *        have a non-empty attribute map.
 */
void ModuleAttrs(ffi::Map<ffi::String, Any> attrs, bool allow_overwrite) {
  if (!IRBuilder::IsInScope()) {
    return;
  }
  // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope
  IRModuleFrame frame = FindModuleFrame("I.ModuleAttr");
  const bool already_set = !frame->attrs.empty();
  if (already_set && !allow_overwrite) {
    LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs;
  }
  frame->attrs = attrs;
}
/*!
 * \brief Look up one attribute on the enclosing IRModule frame.
 *
 * \param key The attribute name.
 * \return The attribute value cast to ObjectRef, or std::nullopt when no
 *         IRBuilder is in scope or the key is absent.
 */
ffi::Optional<ObjectRef> ModuleGetAttr(const ffi::String& key) {
  if (!IRBuilder::IsInScope()) {
    return std::nullopt;
  }
  IRModuleFrame frame = FindModuleFrame();
  auto attr_it = frame->attrs.find(key);
  if (attr_it == frame->attrs.end()) {
    return std::nullopt;
  }
  return (*attr_it).second.cast<ObjectRef>();
}
/*!
 * \brief Set or remove a single attribute on the enclosing IRModule frame.
 *
 * \param key The attribute name.
 * \param value The new value. If undefined, \p key is erased from the attributes.
 * \param allow_override If false, it is a fatal error to set a defined value for
 *        a key that already exists. Erasing is always permitted.
 *
 * It is a fatal error to call this when no IRBuilder is in scope.
 */
void ModuleSetAttr(const ffi::String& key, const ffi::Optional<ObjectRef>& value,
                   bool allow_override) {
  if (IRBuilder::IsInScope()) {
    IRModuleFrame frame = FindModuleFrame();
    if (!allow_override && frame->attrs.find(key) != frame->attrs.end() && value.defined()) {
      LOG(FATAL) << "ValueError: Duplicate module attr " << key;
    }
    if (value.defined()) {
      frame->attrs.Set(key, value.value());
    } else {
      frame->attrs.erase(key);
    }
  } else {
    // This branch runs when we are NOT inside a builder scope; the message must say so.
    LOG(FATAL) << "ValueError: Currently not in the scope of a module.";
  }
}
/*!
 * \brief Install the global-info table of the enclosing IRModule frame.
 *
 * Does nothing when no IRBuilder is in scope. It is a fatal error if the frame
 * already has a non-empty global-info table.
 *
 * \param global_infos Map from a global-info category name to its entries.
 */
void ModuleGlobalInfos(ffi::Map<ffi::String, ffi::Array<GlobalInfo>> global_infos) {
  if (!IRBuilder::IsInScope()) {
    return;
  }
  IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos");
  const bool already_set = !frame->global_infos.empty();
  if (already_set) {
    LOG(FATAL) << "ValueError: Duplicate module global_infos, previous one is:\n"
               << frame->global_infos;
  }
  frame->global_infos = global_infos;
}
/*!
 * \brief Find a VDevice in the "vdevice" global infos of the enclosing IRModule frame.
 *
 * \param target_kind Either the literal "vdevice", which selects the entry at
 *        \p device_index in the full list, or a target kind name, which selects
 *        the \p device_index-th entry (0-based) whose target kind matches.
 * \param device_index The index described above.
 * \return The matching VDevice. If no IRBuilder is in scope or no entry matches,
 *         a warning is logged and an undefined VDevice is returned.
 *
 * Inside a builder scope, it is a fatal error for the global infos to be empty,
 * to lack a "vdevice" entry, or for \p device_index to fall outside that entry's list.
 */
VDevice LookupVDevice(ffi::String target_kind, int device_index) {
  if (IRBuilder::IsInScope()) {
    IRModuleFrame frame = FindModuleFrame();
    if (frame->global_infos.empty()) {
      LOG(FATAL) << "ValueError: The GlobalInfos in the IRModule is not defined.";
    }
    // Use find() rather than operator[] so that global infos lacking a "vdevice"
    // entry report the same ValueError below, instead of failing inside Map::operator[].
    ffi::Array<GlobalInfo> vdevices;
    auto vdevice_it = frame->global_infos.find("vdevice");
    if (vdevice_it != frame->global_infos.end()) {
      vdevices = (*vdevice_it).second;
    }
    if (vdevices.empty() || device_index < 0 ||
        static_cast<size_t>(device_index) >= vdevices.size()) {
      LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found.";
    }
    if (target_kind == "vdevice") {
      return Downcast<VDevice>(vdevices[device_index]);
    }
    // Return the device_index-th VDevice (counting only matches) whose target kind matches.
    int count = 0;
    for (const auto& info : vdevices) {
      VDevice vdev = Downcast<VDevice>(info);
      if (vdev->target->kind->name == target_kind) {
        if (count == device_index) {
          return vdev;
        }
        count++;
      }
    }
  }
  LOG(WARNING) << "The annotated device was not found, please check your vdevice list.";
  return VDevice();
}
// Register the IRModule builder entry points as global FFI functions under the
// "script.ir_builder.ir." prefix so they can be looked up by name (presumably from
// the Python TVMScript IR builder — confirm against the Python bindings).
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("script.ir_builder.ir.IRModule", IRModule)
.def("script.ir_builder.ir.DeclFunction", DeclFunction)
.def("script.ir_builder.ir.DefFunction", DefFunction)
.def("script.ir_builder.ir.ModuleAttrs", ModuleAttrs)
.def("script.ir_builder.ir.ModuleGetAttr", ModuleGetAttr)
.def("script.ir_builder.ir.ModuleSetAttr", ModuleSetAttr)
.def("script.ir_builder.ir.ModuleGlobalInfos", ModuleGlobalInfos)
.def("script.ir_builder.ir.LookupVDevice", LookupVDevice);
}
} // namespace ir
} // namespace ir_builder
} // namespace script
} // namespace tvm