| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/relax/analysis.h> |
| #include <tvm/relax/expr_functor.h> |
| #include <tvm/relax/transform.h> |
| #include <tvm/relax/utils.h> |
| |
| #include <algorithm> |
| #include <optional> |
| #include <tuple> |
| |
| #include "utils.h" |
| |
| namespace tvm { |
| namespace relax { |
| |
| namespace { |
| |
| template <typename T> |
| using PSet = std::unordered_set<T, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>; |
| |
| template <typename T, typename U> |
| using PMap = std::unordered_map<T, U, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>; |
| |
| /* \brief Describes the modifications to be made for a function */ |
| struct CalleeAnalysis { |
| /* \brief The updated private function */ |
| Function func; |
| |
| /* \brief A function that updates the callsite arguments |
| * |
| * \param The arguments used to call the original function |
| * |
| * \return The arguments to be used for the modified function |
| */ |
| std::function<ffi::Array<Expr>(ffi::Array<Expr>)> arg_updater; |
| }; |
| |
| std::optional<CalleeAnalysis> AnalyzeCallee(Function func) { |
| bool is_exposed = func->attrs.GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value(); |
| if (is_exposed) return std::nullopt; |
| |
| auto free_relax_vars = [&]() -> PSet<Var> { |
| auto array_free_vars = FreeVars(func->body); |
| return {array_free_vars.begin(), array_free_vars.end()}; |
| }(); |
| |
| std::vector<bool> parameter_mask; |
| parameter_mask.reserve(func->params.size()); |
| |
| ffi::Array<Var> params; |
| for (const auto& param : func->params) { |
| bool is_used = free_relax_vars.count(param); |
| parameter_mask.push_back(is_used); |
| if (is_used) { |
| params.push_back(param); |
| } |
| } |
| |
| if (func->params.size() == params.size()) { |
| // Early bail-out for the common case where the function uses all |
| // of its parameters. |
| return std::nullopt; |
| } |
| |
| // Even if a parameter is unused, it may provide definitions for |
| // symbolic variables. We still want to remove the relax variable |
| // to reduce computational steps in the parent, but we need to |
| // provide the symbolic variables the other steps. |
| auto defined_tir_params = [&]() -> PSet<tirx::Var> { |
| auto param_sinfo = |
| TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); })); |
| auto arr = DefinableTIRVarsInStructInfo(param_sinfo); |
| return {arr.begin(), arr.end()}; |
| }(); |
| |
| // Use an array to define the order of the symbolic variables |
| ffi::Array<tirx::Var> free_tir_vars; |
| for (const auto& tir_var : FreeSymbolicVars(func->body)) { |
| if (!defined_tir_params.count(tir_var)) { |
| free_tir_vars.push_back(tir_var); |
| } |
| } |
| |
| for (const auto& tir_var : free_tir_vars) { |
| Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var)); |
| params.push_back(relax_var); |
| } |
| |
| FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }), |
| func->ret_struct_info, |
| Downcast<FuncStructInfo>(func->struct_info_)->purity); |
| |
| auto arg_updater = [parameter_mask, old_relax_params = func->params, |
| free_tir_vars](ffi::Array<Expr> old_args) -> ffi::Array<Expr> { |
| TVM_FFI_ICHECK_EQ(old_args.size(), parameter_mask.size()) |
| << "Call provides " << old_args.size() << ", but the callee accepts " |
| << parameter_mask.size() << " parameters"; |
| |
| ffi::Array<Expr> new_args; |
| for (size_t i = 0; i < old_args.size(); i++) { |
| if (parameter_mask.at(i)) { |
| new_args.push_back(old_args[i]); |
| } |
| } |
| |
| if (free_tir_vars.size()) { |
| ffi::Map<Var, Expr> old_binding; |
| for (size_t i = 0; i < old_relax_params.size(); i++) { |
| old_binding.Set(old_relax_params[i], old_args[i]); |
| } |
| arith::Analyzer analyzer; |
| auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer); |
| |
| for (const auto& tir_var : free_tir_vars) { |
| new_args.push_back(PrimValue(tir_binding.at(tir_var))); |
| } |
| } |
| |
| return new_args; |
| }; |
| |
| auto write_ptr = func.CopyOnWrite(); |
| write_ptr->params = params; |
| write_ptr->struct_info_ = new_sinfo; |
| |
| return CalleeAnalysis{func, arg_updater}; |
| } |
| |
| class CallSiteMutator : public ExprMutator { |
| public: |
| explicit CallSiteMutator(PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters) |
| : callsite_updaters_(callsite_updaters) {} |
| |
| using ExprMutator::VisitExpr_; |
| |
| Expr VisitExpr_(const FunctionNode* op) override { |
| auto node = ExprMutator::VisitExpr_(op); |
| |
| // If a function was modified, that means it called into a private |
| // function that now takes a reduced number of arguments. Some |
| // bindings in the calling scope, previously used to define those |
| // unused arguments, may be able to be removed as a result. |
| if (node.get() != op) { |
| node = RemoveAllUnused(node); |
| } |
| return node; |
| } |
| |
| Expr VisitExpr_(const CallNode* op) override { |
| auto node = Downcast<Call>(ExprMutator::VisitExpr_(op)); |
| |
| if (auto gvar = node->op.as<GlobalVar>()) { |
| if (auto it = callsite_updaters_.find(gvar.value()); it != callsite_updaters_.end()) { |
| node = it->second(std::move(node)); |
| } |
| } |
| |
| return node; |
| } |
| |
| PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters_; |
| }; |
| |
| } // namespace |
| |
| namespace transform { |
| |
| Pass RemoveUnusedParameters() { |
| auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { |
| PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters; |
| |
| { |
| IRModule new_callees; |
| |
| for (const auto& [gvar, base_func] : mod->functions) { |
| if (auto func = base_func.as<Function>()) { |
| if (auto callee_res = AnalyzeCallee(func.value())) { |
| auto new_func = callee_res->func; |
| GlobalVar new_gvar(gvar->name_hint); |
| new_gvar->struct_info_ = new_func->struct_info_; |
| new_callees->Add(new_gvar, new_func); |
| |
| callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, |
| arg_updater = callee_res->arg_updater](Call call) -> Call { |
| TVM_FFI_CHECK(call->op.same_as(old_gvar), InternalError) |
| << "Updater should be applied to " << old_gvar << ", but was applied to " |
| << call->op; |
| auto write_ptr = call.CopyOnWrite(); |
| write_ptr->op = new_gvar; |
| write_ptr->args = arg_updater(call->args); |
| return call; |
| }; |
| } |
| } |
| } |
| |
| if (callsite_updaters.empty()) { |
| return mod; |
| } |
| auto write_ptr = mod.CopyOnWrite(); |
| |
| // Remove any private subroutines that have unused parameters, |
| // then add the updated versions. The new private functions |
| // have the same name, but require a new GlobalVar to hold the |
| // updated StructInfo. As a result, calling `Update()` without |
| // first calling `Remove()` introduce a duplicate name and |
| // produce an error. |
| for (const auto& it : callsite_updaters) { |
| write_ptr->Remove(it.first); |
| } |
| write_ptr->Update(new_callees); |
| } |
| |
| CallSiteMutator mutator(std::move(callsite_updaters)); |
| |
| IRModule caller_updates; |
| |
| for (const auto& [gvar, base_func] : mod->functions) { |
| if (auto func = base_func.as<Function>()) { |
| auto mutated = Downcast<Function>(mutator.VisitExpr(func.value())); |
| if (!mutated.same_as(base_func)) { |
| caller_updates->Add(gvar, mutated); |
| } |
| } |
| } |
| |
| if (caller_updates->functions.size()) { |
| mod.CopyOnWrite()->Update(caller_updates); |
| } |
| return mod; |
| }; |
| return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {}); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("relax.transform.RemoveUnusedParameters", RemoveUnusedParameters); |
| } |
| |
| } // namespace transform |
| |
| } // namespace relax |
| } // namespace tvm |