| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/struct_info.h>
#include <tvm/script/ir_builder/relax/ir.h>
#include <tvm/tir/op.h>

#include <algorithm>

#include "./utils.h"
| |
| namespace tvm { |
| namespace script { |
| namespace ir_builder { |
| namespace relax { |
| |
| ///////////////////////////////// Vars ////////////////////////////////// |
| |
| using tvm::script::ir_builder::details::Namer; |
| |
| TVM_STATIC_IR_FUNCTOR(Namer, vtable) |
| .set_dispatch<tvm::relax::VarNode>([](const ObjectRef& node, ffi::String name) -> void { |
| using tvm::relax::VarNode; |
| using tvm::relax::IdNode; |
| const VarNode* var = node.as<VarNode>(); |
| IdNode* vid = const_cast<IdNode*>(var->vid.get()); |
| vid->name_hint = name; |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(Namer, vtable) |
| .set_dispatch<tvm::relax::DataflowVarNode>([](const ObjectRef& node, ffi::String name) -> void { |
| using tvm::relax::DataflowVarNode; |
| using tvm::relax::IdNode; |
| const DataflowVarNode* var = node.as<DataflowVarNode>(); |
| IdNode* vid = const_cast<IdNode*>(var->vid.get()); |
| vid->name_hint = name; |
| }); |
| |
| /////////////////////////////// Function //////////////////////////////// |
| |
| FunctionFrame Function(const Bool& is_pure, const Bool& is_private) { |
| ObjectPtr<FunctionFrameNode> n = ffi::make_object<FunctionFrameNode>(); |
| const IRBuilder& ir_builder = IRBuilder::Current(); |
| ffi::Optional<tvm::IRModule> mod = std::nullopt; |
| if (const ffi::Optional<ir::IRModuleFrame> mod_frame = |
| ir_builder->GetLastFrame<ir::IRModuleFrame>()) { |
| mod = tvm::IRModule(mod_frame.value()->functions); |
| } |
| n->block_builder = tvm::relax::BlockBuilder::Create( |
| /*mod=*/mod, tvm::relax::BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript()); |
| n->is_pure = is_pure; |
| n->is_private = is_private; |
| return FunctionFrame(n); |
| } |
| |
| tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struct_info) { |
| FunctionFrame frame = FindFunctionFrame("R.Arg"); |
| tvm::relax::Var var(name, struct_info); |
| frame->params.push_back(var); |
| frame->block_builder->AddDefinitionToScope(var); |
| |
| return var; |
| } |
| |
| void FuncName(const ffi::String& name) { |
| FunctionFrame frame = FindFunctionFrame("R.func_name"); |
| if (frame->name.has_value()) { |
| LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() |
| << "\""; |
| } |
| frame->name = name; |
| } |
| |
| void FuncAttrs(ffi::Map<ffi::String, ffi::Any> attrs) { |
| FunctionFrame frame = FindFunctionFrame("R.func_attr"); |
| for (const auto& [key, value] : attrs) { |
| if (key == tvm::attr::kGlobalSymbol && frame->is_private.value_or(Bool(false))->value) { |
| LOG(FATAL) << "ValueError: " |
| << "A private function may not have the kGlobalSymbol (\"" |
| << tvm::attr::kGlobalSymbol << "\") attribute. " |
| << "However, a private function specified the global symbol as " << value; |
| } |
| if (auto prev = frame->attrs.Get(key)) { |
| LOG(FATAL) << "ValueError: " |
| << "Duplicate R.func_attr annotation for key = \"" << key << "\". " |
| << "Previous value was " << prev.value() << ", with later definition as " << value; |
| } else { |
| frame->attrs.Set(key, value); |
| } |
| } |
| } |
| |
| void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { |
| FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); |
| if (frame->ret_struct_info.defined()) { |
| LOG(FATAL) << "ValueError: Duplicate function return struct info, previous one is:\n " |
| << frame->ret_struct_info.value(); |
| } |
| frame->ret_struct_info = ret_sinfo; |
| } |
| |
| void FuncRetValue(const tvm::relax::Expr& value) { |
| // Step 0. Normalize the value. |
| const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); |
| tvm::relax::Expr normalized_value = block_builder->Normalize(value); |
| |
| IRBuilder ir_builder = IRBuilder::Current(); |
| |
| // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of |
| // a function body. Therefore if there is any unended block frame when dealing with function |
| // return, we should end the block frame. |
| |
| if (auto opt = ir_builder->GetLastFrame<BindingBlockFrame>()) { |
| auto block_frame = opt.value(); |
| for (const auto& var : tvm::relax::FreeVars(normalized_value)) { |
| if (var->IsInstance<tvm::relax::DataflowVarNode>()) { |
| block_frame->output_vars.push_back(var); |
| } |
| } |
| } |
| // Step 2. Add the output value to the function frame. |
| FunctionFrame frame = FindFunctionFrame("return"); |
| CHECK(!frame->output.defined()) |
| << "ValueError: " |
| << "Relax functions do not support multiple return statement. " |
| << "However, return of " << normalized_value << " occurred after a return of " |
| << frame->output << ". " |
| << "Please make sure function only has a single return statement, " |
| << "which appears at the end of function."; |
| |
| frame->output = std::move(normalized_value); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("script.ir_builder.relax.Function", Function) |
| .def("script.ir_builder.relax.Arg", Arg) |
| .def("script.ir_builder.relax.FuncName", FuncName) |
| .def("script.ir_builder.relax.FuncAttrs", FuncAttrs) |
| .def("script.ir_builder.relax.FuncRetStructInfo", FuncRetStructInfo) |
| .def("script.ir_builder.relax.FuncRetValue", FuncRetValue); |
| } |
| |
| ///////////////////////////// BindingBlock ////////////////////////////// |
| |
| BindingBlockFrame Dataflow() { |
| ObjectPtr<BindingBlockFrameNode> n = ffi::make_object<BindingBlockFrameNode>(); |
| n->is_dataflow = true; |
| n->block_ended = false; |
| return BindingBlockFrame(n); |
| } |
| |
| BindingBlockFrame BindingBlock() { |
| ObjectPtr<BindingBlockFrameNode> n = ffi::make_object<BindingBlockFrameNode>(); |
| n->is_dataflow = false; |
| n->block_ended = false; |
| return BindingBlockFrame(n); |
| } |
| |
| void DataflowBlockOutput(const ffi::Array<tvm::relax::Var>& vars) { |
| // Step 1. Check that we're in a Dataflow block that is not ended. |
| ffi::Optional<BindingBlockFrame> block_frame = |
| IRBuilder::Current()->GetLastFrame<BindingBlockFrame>(); |
| CHECK(block_frame.defined() && block_frame.value()->is_dataflow) |
| << "ValueError: `R.output` should appear inside a dataflow block. However, the current " |
| "innermost block is not a dataflow block."; |
| CHECK(!block_frame.value()->block_ended) |
| << "ValueError: It is not allowed for a dataflow block to have multiple output operation."; |
| |
| // Step 2. Mark the block frame ended of construction, so that any followup binding after this |
| // mark in the dataflow block will lead to an error. |
| block_frame.value()->block_ended = true; |
| |
| // Step 3. All the output variables must be global variables and must be emitted by this dataflow |
| // block. |
| const ffi::Array<tvm::relax::Var>& emitted_vars = block_frame.value()->emitted_vars; |
| for (const tvm::relax::Var& var : vars) { |
| CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) |
| << "ValueError: An output variable is not emitted by this dataflow block. Please make sure " |
| "all dataflow block output variables are emitted exactly by this block."; |
| block_frame.value()->output_vars.push_back(var); |
| } |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("script.ir_builder.relax.Dataflow", Dataflow) |
| .def("script.ir_builder.relax.BindingBlock", BindingBlock) |
| .def("script.ir_builder.relax.DataflowBlockOutput", DataflowBlockOutput); |
| } |
| |
| /////////////////////////////// Bindings /////////////////////////////// |
| |
| tvm::relax::Var Emit(const tvm::relax::Expr& expr, |
| const ffi::Optional<tvm::relax::StructInfo>& annotate_struct_info) { |
| using tvm::relax::GetStructInfo; |
| BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); |
| const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); |
| if (annotate_struct_info.defined()) { |
| const auto& sinfo = annotate_struct_info.value(); |
| if (!expr->struct_info_.defined()) { |
| UpdateStructInfo(expr, sinfo); |
| } else { |
| CHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != tvm::relax::BaseCheckResult::kFailL0) |
| << "Invalid annotation. Got rhs value struct info: " << GetStructInfo(expr) |
| << ", given struct info: " << sinfo; |
| } |
| } |
| tvm::relax::Var var = block_builder->Emit(expr); |
| block_frame->emitted_vars.push_back(var); |
| return var; |
| } |
| |
| tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, |
| const tvm::relax::StructInfo& struct_info) { |
| BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); |
| const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); |
| |
| tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); |
| block_frame->emitted_vars.push_back(var); |
| return var; |
| } |
| |
| tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { |
| BindingBlockFrame block_frame = CheckBindingBlockFrameExistAndUnended(); |
| const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); |
| block_builder->EmitNormalized(binding); |
| block_frame->emitted_vars.push_back(binding->var); |
| return binding->var; |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("script.ir_builder.relax.Emit", Emit) |
| .def("script.ir_builder.relax.EmitMatchCast", EmitMatchCast) |
| .def("script.ir_builder.relax.EmitVarBinding", EmitVarBinding); |
| } |
| |
| /////////////////////////////// SeqExpr /////////////////////////////// |
| |
| SeqExprFrame SeqExpr() { |
| ObjectPtr<SeqExprFrameNode> n = ffi::make_object<SeqExprFrameNode>(); |
| return SeqExprFrame(n); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("script.ir_builder.relax.SeqExpr", SeqExpr); |
| } |
| |
| ///////////////////////////// If Then Else ///////////////////////////// |
| |
| IfFrame If(tvm::relax::Expr condition) { |
| ObjectPtr<IfFrameNode> n = ffi::make_object<IfFrameNode>(); |
| n->condition = condition; |
| n->then_expr = std::nullopt; |
| n->else_expr = std::nullopt; |
| return IfFrame(n); |
| } |
| |
| ThenFrame Then() { |
| ObjectPtr<ThenFrameNode> n = ffi::make_object<ThenFrameNode>(); |
| return ThenFrame(n); |
| } |
| |
| ElseFrame Else() { |
| ObjectPtr<ElseFrameNode> n = ffi::make_object<ElseFrameNode>(); |
| return ElseFrame(n); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("script.ir_builder.relax.If", If) |
| .def("script.ir_builder.relax.Then", Then) |
| .def("script.ir_builder.relax.Else", Else); |
| } |
| |
| } // namespace relax |
| } // namespace ir_builder |
| } // namespace script |
| } // namespace tvm |