| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ir/function.h> |
| #include <tvm/relax/expr.h> |
| #include <tvm/relax/expr_functor.h> |
| #include <tvm/relax/transform.h> |
| #include <tvm/relax/type.h> |
| #include <tvm/tirx/op.h> |
| |
| #include <tuple> |
| #include <utility> |
| |
| namespace tvm { |
| namespace relax { |
| |
| void MatchSymbolicVar(const Expr& arg, const Expr& constant, |
| ffi::Map<tirx::Var, PrimExpr>* symbolic_var_map, arith::Analyzer* analyzer_) { |
| auto opt_arg_sinfo = MatchStructInfo<TensorStructInfo>(arg); |
| TVM_FFI_ICHECK(opt_arg_sinfo) |
| << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " |
| << GetStructInfo(arg); |
| auto opt_const_sinfo = MatchStructInfo<TensorStructInfo>(constant); |
| // As the constant is generated by internal codes, we use TVM_FFI_ICHECK here. |
| TVM_FFI_ICHECK(opt_const_sinfo) |
| << "The struct info of the bound weight is expected to be TensorStructInfo, but got: " |
| << GetStructInfo(constant); |
| |
| TensorStructInfo arg_sinfo = opt_arg_sinfo.value(); |
| TensorStructInfo const_sinfo = opt_const_sinfo.value(); |
| TVM_FFI_ICHECK(!const_sinfo->IsUnknownDtype()); |
| TVM_FFI_ICHECK(!const_sinfo->IsUnknownNdim()); |
| TVM_FFI_ICHECK(const_sinfo->shape.defined()); |
| |
| // dtype mismatch |
| if (!arg_sinfo->IsUnknownDtype() && arg_sinfo->dtype != const_sinfo->dtype) { |
| TVM_FFI_THROW(InternalError) << "The dtype of the bound parameter is expected to be " |
| << arg_sinfo->dtype << ", but got: " << const_sinfo->dtype; |
| } |
| // ndim mismatch |
| if (!arg_sinfo->IsUnknownNdim() && arg_sinfo->ndim != const_sinfo->ndim) { |
| TVM_FFI_THROW(InternalError) << "The ndim of the bound parameter is expected to be " |
| << arg_sinfo->ndim << ", but got: " << const_sinfo->ndim; |
| } |
| if (!arg_sinfo->shape.defined()) return; |
| const auto* arg_shape = arg_sinfo->shape.value().as<ShapeExprNode>(); |
| const auto* const_shape = const_sinfo->shape.value().as<ShapeExprNode>(); |
| |
| TVM_FFI_ICHECK(arg_shape && const_shape) |
| << "The shape of the bound parameter and weight is expected to be ShapeExprNode for now"; |
| |
| for (int i = 0; i < arg_sinfo->ndim; ++i) { |
| const PrimExpr& const_dim = const_shape->values[i]; |
| TVM_FFI_ICHECK(tirx::is_const_int(const_dim)); |
| if (const auto* shape_var = arg_shape->values[i].as<tirx::VarNode>()) { |
| auto it = symbolic_var_map->find(ffi::GetRef<tirx::Var>(shape_var)); |
| if (it == symbolic_var_map->end()) { |
| symbolic_var_map->Set(ffi::GetRef<tirx::Var>(shape_var), const_dim); |
| } else { |
| TVM_FFI_ICHECK(analyzer_->CanProveEqual((*it).second, const_dim)) |
| << "The shape of the bound parameter is expected to be " << (*it).second |
| << ", but got: " << const_dim; |
| } |
| } |
| } |
| } |
| |
| std::tuple<ffi::Map<Var, Expr>, ffi::Map<tirx::Var, PrimExpr>> NormalizeBindings( |
| const Function& func, const ffi::Map<Any, ObjectRef>& untyped_params) { |
| TVM_FFI_ICHECK(func.defined()); |
| TVM_FFI_ICHECK(untyped_params.defined()); |
| |
| // Map from string to the variable(s) with that name. |
| std::unordered_map<std::string, ffi::Array<relax::Var>> string_lookup; |
| std::unordered_set<const relax::VarNode*> var_set; |
| for (const auto& param : func->params) { |
| string_lookup[param->name_hint()].push_back(param); |
| var_set.insert(param.get()); |
| } |
| |
| ffi::Map<relax::Var, relax::Expr> relax_var_remap; |
| |
| auto normalize_key = [&](ffi::Any obj) -> relax::Var { |
| if (auto opt_str = obj.as<ffi::String>()) { |
| std::string str = opt_str.value(); |
| auto it = string_lookup.find(str); |
| TVM_FFI_ICHECK(it != string_lookup.end()) |
| << "Function does not have parameter with name \"" << str << "\". " |
| << "Function parameters are named " |
| << func->params.Map([](const auto& param) { return param->name_hint(); }); |
| TVM_FFI_ICHECK_EQ(it->second.size(), 1) |
| << "Function contains multiple parameters with name \"" << str << "\". " |
| << "The Relax variables " << it->second << " are all named \"" << str << "\""; |
| auto var = it->second[0]; |
| TVM_FFI_ICHECK(!relax_var_remap.count(var)) |
| << "Remap of variable " << var << " was defined multiple times"; |
| |
| return var; |
| } else if (auto opt_var = obj.as<relax::Var>()) { |
| auto var = opt_var.value(); |
| TVM_FFI_ICHECK(!relax_var_remap.count(var)) |
| << "Remap of variable " << var << " was defined multiple times"; |
| TVM_FFI_ICHECK(var_set.count(var.get())) |
| << "Function does not use Relax variable " << var << " as a parameter. " |
| << "Function parameters are " << func->params; |
| return var; |
| } else { |
| TVM_FFI_THROW(InternalError) |
| << "Expected bound parameter to be a relax::Var, " |
| << " or a string that uniquely identifies a relax::Var param within the function. " |
| << "However, received object " << obj << " of type " << obj.GetTypeKey(); |
| } |
| }; |
| auto normalize_value = [&](ffi::Any obj) -> relax::Expr { |
| if (auto opt = obj.as<relax::Expr>()) { |
| return opt.value(); |
| } else if (auto opt = obj.as<runtime::Tensor>()) { |
| return Constant(opt.value()); |
| } else { |
| TVM_FFI_THROW(InternalError) |
| << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; |
| } |
| }; |
| |
| for (const auto& [key, value] : untyped_params) { |
| relax_var_remap.Set(normalize_key(key), normalize_value(value)); |
| } |
| |
| arith::Analyzer analyzer; |
| ffi::Map<tirx::Var, PrimExpr> symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); |
| |
| // for (const auto& [bind_param, bind_expr] : relax_var_remap) { |
| // MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer); |
| // } |
| |
| return {relax_var_remap, symbolic_var_map}; |
| } |
| |
| /*! |
| * \brief Bind params to function by using name |
| * \param func Relax function |
| * \param params params dict |
| * \return Function |
| */ |
| Function FunctionBindParams(Function func, const ffi::Map<Any, ObjectRef>& untyped_params) { |
| auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params); |
| |
| Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); |
| return Downcast<Function>(bound_expr); |
| } |
| |
| /*! |
| * \brief Bind params to a specific function in a module |
| * \param m The module |
| * \param func_name The name of the specific function |
| * \param param The param dict |
| * \return The module after binding params. |
| */ |
| IRModule BindParam(IRModule m, ffi::String func_name, ffi::Map<Any, ObjectRef> bind_params) { |
| IRModuleNode* new_module = m.CopyOnWrite(); |
| ffi::Map<GlobalVar, BaseFunc> functions = m->functions; |
| for (const auto& func_pr : functions) { |
| if (const auto* relax_f = func_pr.second.as<FunctionNode>()) { |
| if (relax_f->GetLinkageType() == LinkageType::kExternal) { |
| // Use global_symbol if it's external linkage |
| ffi::Optional<ffi::String> gsymbol = |
| relax_f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol); |
| if (gsymbol.has_value() && gsymbol.value() == func_name) { |
| Function f_after_bind = FunctionBindParams(ffi::GetRef<Function>(relax_f), bind_params); |
| new_module->Update(func_pr.first, f_after_bind); |
| } |
| } else { |
| // Use global var's name_hint if it's internal linkage |
| if (func_pr.first->name_hint == func_name) { |
| Function f_after_bind = FunctionBindParams(ffi::GetRef<Function>(relax_f), bind_params); |
| new_module->Update(func_pr.first, f_after_bind); |
| } |
| } |
| } |
| } |
| return ffi::GetRef<IRModule>(new_module); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("relax.FunctionBindParams", FunctionBindParams); |
| } |
| |
| namespace transform { |
| |
| Pass BindParams(ffi::String func_name, ffi::Map<Any, ObjectRef> params) { |
| auto pass_func = [=](IRModule mod, PassContext pc) { |
| return BindParam(std::move(mod), func_name, params); |
| }; |
| return CreateModulePass(pass_func, 0, "BindParams", {}); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("relax.transform.BindParams", BindParams); |
| } |
| |
| } // namespace transform |
| |
| } // namespace relax |
| } // namespace tvm |