| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * |
| * \file simplify_fc_transpose.cc |
| * |
| * \brief Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to |
| * ```y = nn.dense(x, wt)``` |
| */ |
| #include <tvm/ir/expr.h> |
| #include <tvm/relay/analysis.h> |
| #include <tvm/relay/attrs/nn.h> |
| #include <tvm/relay/attrs/transform.h> |
| #include <tvm/relay/expr_functor.h> |
| #include <tvm/relay/op_attr_types.h> |
| #include <tvm/relay/transform.h> |
| |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| namespace tvm { |
| namespace relay { |
| |
| // Find name of weight in ```y = nn.dense(x, tranpose(w, [1, 0]))``` |
| class FCTransposeVisitor : private ExprVisitor { |
| public: |
| FCTransposeVisitor() : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {} |
| |
| Array<String> Search(const Expr& expr) { |
| VisitExpr(expr); |
| return memo_; |
| } |
| |
| private: |
| void VisitExpr_(const CallNode* n) final { |
| if (n->op == dense_op_) { |
| const auto weight = n->args[1].as<CallNode>(); |
| if (weight) { |
| if (weight->op == transpose_op_) { |
| if (weight->args[0].as<VarNode>()) { |
| const auto arg = weight->args[0].as<VarNode>(); |
| memo_.push_back(arg->name_hint()); |
| } |
| } |
| } |
| } |
| for (const auto& arg : n->args) { |
| VisitExpr(arg); |
| } |
| } |
| |
| const Op& dense_op_; |
| const Op& transpose_op_; |
| Array<String> memo_; |
| }; // SearchDenseOpWeight |
| |
| Array<String> SearchFCTranspose(const Expr& e) { return FCTransposeVisitor().Search(e); } |
| |
| TVM_REGISTER_GLOBAL("relay.analysis.search_fc_transpose").set_body_typed(SearchFCTranspose); |
| |
| // Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)``` |
| class FCTransposeMutator : public ExprRewriter { |
| public: |
| explicit FCTransposeMutator(const Array<ObjectRef>& target_weights) |
| : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) { |
| for (size_t i = 0; i < target_weights.size(); ++i) { |
| CHECK(target_weights[i]->IsInstance<runtime::StringObj>()); |
| std::string k = target_weights[i].as<runtime::StringObj>()->data; |
| target_weights_.emplace(k); |
| } |
| } |
| |
| Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
| if (pre->op == dense_op_) { |
| const auto data = post.as<CallNode>()->args[0]; |
| const auto weight = pre->args[1].as<CallNode>(); |
| if (weight) { |
| if (weight->op == transpose_op_) { |
| const auto arg = weight->args[0]; |
| if (arg.as<VarNode>()) { |
| const auto& arg_node = arg.as<VarNode>(); |
| CHECK_GT(target_weights_.count(arg_node->name_hint()), 0); |
| const auto& tt = arg_node->type_annotation.as<TensorTypeNode>(); |
| auto wt_type = TensorType({tt->shape[1], tt->shape[0]}, tt->dtype); |
| Var wt(arg_node->name_hint() + ".T", wt_type); |
| return Call(dense_op_, {data, wt}, pre->attrs, pre->type_args); |
| } |
| } |
| } |
| } |
| return post; |
| } |
| |
| private: |
| // Cached op |
| const Op& dense_op_; |
| const Op& transpose_op_; |
| std::unordered_set<std::string> target_weights_; |
| }; // class DenseToSparseDenseAlter |
| |
| Expr SimplifyFCTranspose(const Expr& e, const Array<ObjectRef>& target_weights) { |
| auto rewriter = FCTransposeMutator(target_weights); |
| return PostOrderRewrite(e, &rewriter); |
| } |
| |
| namespace transform { |
| |
| Pass SimplifyFCTranspose(const Array<ObjectRef>& target_weights) { |
| runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
| [=](Function f, IRModule m, PassContext pc) { |
| // Remove FreeVar warning |
| auto f0 = Downcast<Function>(SimplifyFCTranspose(f, target_weights)); |
| Array<Var> wt_params = FreeVars(f0); |
| auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); |
| Array<Var> params = FreeVars(f1); |
| for (const auto& var : wt_params) { |
| params.push_back(var); |
| } |
| return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); |
| }; |
| return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); |
| } |
| |
| TVM_REGISTER_GLOBAL("relay._transform.SimplifyFCTranspose").set_body_typed(SimplifyFCTranspose); |
| |
| } // namespace transform |
| |
| } // namespace relay |
| } // namespace tvm |