blob: 658e76be466c57ec1ed464ed78f347960f6ab275 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
#include <tvm/script/ir_builder/base.h>
namespace tvm {
namespace script {
namespace ir_builder {
// Static-initialization hook: invokes the RegisterReflection() routine of the
// frame node type and of the builder node type.
TVM_FFI_STATIC_INIT_BLOCK() {
IRBuilderFrameNode::RegisterReflection();
IRBuilderNode::RegisterReflection();
}
void IRBuilderFrameNode::EnterWithScope() {
IRBuilder::Current()->frames.push_back(ffi::GetRef<IRBuilderFrame>(this));
}
void IRBuilderFrameNode::ExitWithScope() {
for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
(*it)();
}
this->callbacks.clear();
IRBuilder::Current()->frames.pop_back();
}
// Register a callback on the last frame of the builder currently in scope.
// Raises a fatal error if that builder holds no frames.
// NOTE(review): the callback is attached to the current builder's last frame,
// not to `this` - presumably intentional; confirm against callers.
void IRBuilderFrameNode::AddCallback(ffi::TypedFunction<void()> callback) {
  IRBuilder active_builder = IRBuilder::Current();
  if (active_builder->frames.empty()) {
    LOG(FATAL) << "ValueError: No frames in Builder to add callback";
  }
  active_builder->frames.back()->callbacks.push_back(callback);
}
// Construct a builder backed by a fresh node with no frames and no result.
IRBuilder::IRBuilder() {
  ObjectPtr<IRBuilderNode> node = ffi::make_object<IRBuilderNode>();
  node->result = std::nullopt;
  node->frames.clear();
  data_ = node;
}
// Return a pointer to the calling thread's stack of builders. The storage is
// a function-local thread_local vector, so each thread sees its own stack.
std::vector<IRBuilder>* ThreadLocalBuilderStack() {
  thread_local std::vector<IRBuilder> builders;
  std::vector<IRBuilder>* handle = &builders;
  return handle;
}
void IRBuilder::EnterWithScope() {
IRBuilderNode* n = this->get();
CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
<< n->frames.size()
<< ". Please use a fresh new builder every time building IRs";
n->result = std::nullopt;
std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
stack->push_back(*this);
}
void IRBuilder::ExitWithScope() {
std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
ICHECK(!stack->empty());
stack->pop_back();
}
// Return the builder on top of the thread-local builder stack.
// Fails the check with a ValueError message when the stack is empty.
IRBuilder IRBuilder::Current() {
  const std::vector<IRBuilder>& builders = *ThreadLocalBuilderStack();
  CHECK(!builders.empty()) << "ValueError: No builder in current scope";
  return builders.back();
}
// Report whether the thread-local builder stack holds at least one builder.
bool IRBuilder::IsInScope() {
  return !ThreadLocalBuilderStack()->empty();
}
namespace details {
// Accessor for the single dispatch table used by Namer; the table is a
// function-local static and is returned by reference.
Namer::FType& Namer::vtable() {
  static FType dispatch_table;
  return dispatch_table;
}
// Assign `name` to `node` by dispatching through the Namer vtable.
//   node: the object to name; must be defined.
//   name: the name passed to the dispatched naming routine.
// Fails a CHECK (ValueError message) when `node` is undefined or when the
// vtable cannot dispatch on the node's type.
void Namer::Name(ObjectRef node, ffi::String name) {
  static const FType& f = vtable();
  CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
  // The message opens a quote before the type key; close it after the key so
  // the reported text is well-formed.
  CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
                              << node->GetTypeKey() << "\"";
  f(node, name);
}
} // namespace details
// Static-initialization hook: registers the frame and builder entry points
// defined in this file as global functions under "script.ir_builder.*" names.
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
// Frame scope management and callback registration.
.def_method("script.ir_builder.IRBuilderFrameEnter", &IRBuilderFrameNode::EnterWithScope)
.def_method("script.ir_builder.IRBuilderFrameExit", &IRBuilderFrameNode::ExitWithScope)
.def_method("script.ir_builder.IRBuilderFrameAddCallback", &IRBuilderFrameNode::AddCallback)
// Builder construction, scope management, and queries.
.def("script.ir_builder.IRBuilder", []() { return IRBuilder(); })
.def_method("script.ir_builder.IRBuilderEnter", &IRBuilder::EnterWithScope)
.def_method("script.ir_builder.IRBuilderExit", &IRBuilder::ExitWithScope)
.def("script.ir_builder.IRBuilderCurrent", IRBuilder::Current)
.def("script.ir_builder.IRBuilderIsInScope", IRBuilder::IsInScope)
// Result retrieval and naming, instantiated for ObjectRef.
.def_method("script.ir_builder.IRBuilderGet", &IRBuilderNode::Get<ObjectRef>)
.def("script.ir_builder.IRBuilderName", IRBuilder::Name<ObjectRef>);
}
} // namespace ir_builder
} // namespace script
} // namespace tvm