/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
| |
/*!
 * \file src/relax/transform/gradient_simplifier.cc
 * \brief Simplify patterns generated by the gradient pass. Only used in gradient.cc.
 * \sa tvm/relax/transform/gradient.cc
 *
 * We will simplify these patterns:
 * (transpose means use permute_dims to transpose the last two dimensions)
 * 1. Forward is: out = matmul(a, transpose(b))
 *    Then backward is:
 *      grad_a = matmul(grad_out, transpose(transpose(b)))
 *      grad_b = transpose(matmul(transpose(a), grad_out))
 *    We will simplify it to:
 *      grad_a = matmul(grad_out, b)
 *      grad_b = matmul(transpose(grad_out), a)
 * 2. Forward is: out = matmul(transpose(a), b)
 *    Then backward is:
 *      grad_a = transpose(matmul(grad_out, transpose(b)))
 *      grad_b = matmul(transpose(transpose(a)), grad_out)
 *    We will simplify it to:
 *      grad_a = matmul(b, transpose(grad_out))
 *      grad_b = matmul(a, grad_out)
 * 3. Forward is: out = matmul(transpose(a), transpose(b))
 *    Then backward is:
 *      grad_a = transpose(matmul(grad_out, transpose(transpose(b))))
 *      grad_b = transpose(matmul(transpose(transpose(a)), grad_out))
 *    We will simplify it to:
 *      grad_a = matmul(transpose(b), transpose(grad_out))
 *      grad_b = matmul(transpose(grad_out), transpose(a))
 */
| |
| #include "gradient_simplifier.h" |
| |
| #include <tvm/relax/analysis.h> |
| #include <tvm/relax/attrs/manipulate.h> |
| #include <tvm/relax/expr.h> |
| #include <tvm/relax/expr_functor.h> |
| |
| #include "../op/tensor/linear_algebra.h" |
| #include "../op/tensor/manipulate.h" |
| |
| namespace tvm { |
| namespace relax { |
| |
| /*! |
| * \brief Simplify patterns generated by the gradient pass. Especially, simplify the matmul |
| * patterns. |
| */ |
| class GradientSimplifier : private ExprMutator { |
| public: |
| /*! |
| * \brief Collect all variables that needs to be checkpointed, and remove the start_checkpoint |
| * and the end_checkpoint bindings. |
| * |
| * \param func The original function |
| * \return The function with all start_checkpoint and end_checkpoint bindings removed, and a |
| * VarIdSet containing all checkpointed vars. |
| */ |
| static Function Transform(const Function& func) { |
| return Downcast<Function>(RemoveAllUnused(GradientSimplifier().VisitExpr(func))); |
| } |
| |
| private: |
| static bool IsTransposeOp(const CallNode* call_node) { |
| if (call_node->op != Op::Get("relax.permute_dims")) { |
| return false; |
| } |
| auto sinfo = MatchStructInfo<TensorStructInfo>(call_node->args[0]); |
| if (!sinfo) { |
| return false; |
| } |
| auto ndim = sinfo.value()->ndim; |
| if (ndim == kUnknownNDim || ndim == 1) { |
| return false; |
| } |
| if (!call_node->attrs.as<PermuteDimsAttrs>()->axes.defined()) { |
| return ndim == 2; |
| } |
| auto axes = call_node->attrs.as<PermuteDimsAttrs>()->axes.value(); |
| TVM_FFI_ICHECK(static_cast<int>(axes.size()) == ndim); |
| for (int i = 0; i < ndim - 2; ++i) { |
| if (axes[i] != i) { |
| return false; |
| } |
| } |
| return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2; |
| } |
| |
| // Return permute_dims(expr). Generate the axes needed. |
| static Expr GetTransposeOf(const Expr& expr) { |
| auto sinfo = MatchStructInfo<TensorStructInfo>(expr); |
| TVM_FFI_ICHECK(sinfo); |
| auto ndim = sinfo.value()->ndim; |
| if (ndim == 1) { |
| return expr; |
| } |
| auto axes = ffi::Array<Integer>(); |
| for (int i = 0; i < ndim - 2; ++i) { |
| axes.push_back(i); |
| } |
| axes.push_back(ndim - 1); |
| axes.push_back(ndim - 2); |
| return permute_dims(expr, axes); |
| } |
| |
| // If expr is already in the form of permute_dims in previous bindings, return the input of the |
| // permute_dims op |
| // Else, return permute_dims(expr) |
| Expr GetTransposeAccordingToCtx(const Expr& expr) { |
| if (!expr->IsInstance<VarNode>()) { |
| return GetTransposeOf(expr); |
| } |
| auto prev_expr = builder_->LookupBinding(Downcast<Var>(expr)); |
| if (!prev_expr || !prev_expr->IsInstance<CallNode>()) { |
| return GetTransposeOf(expr); |
| } |
| auto prev_call_node = prev_expr.as<CallNode>(); |
| if (!IsTransposeOp(prev_call_node)) { |
| return GetTransposeOf(expr); |
| } |
| return prev_call_node->args[0]; |
| } |
| |
| void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) { |
| auto result = ExprMutator::VisitExpr(ffi::GetRef<Expr>(call_node)); |
| auto new_call_node = result.as<CallNode>(); |
| auto reemit_and_return = [&]() { |
| ReEmitBinding(binding, result); |
| return; |
| }; |
| |
| if (!IsTransposeOp(new_call_node)) { |
| return reemit_and_return(); |
| } |
| |
| auto arg = new_call_node->args[0]; |
| if (!arg->IsInstance<VarNode>()) { |
| return reemit_and_return(); |
| } |
| |
| auto prev_expr = builder_->LookupBinding(Downcast<Var>(arg)); |
| if (!prev_expr || !prev_expr->IsInstance<CallNode>()) { |
| return reemit_and_return(); |
| } |
| |
| auto prev_call_node = prev_expr.as<CallNode>(); |
| if (IsTransposeOp(prev_call_node)) { |
| // rewrite rule #1: permute_dims(permute_dims(a)) -> a |
| if (prev_call_node->args[0]->IsInstance<VarNode>()) { |
| var_remap_[binding->var->vid] = Downcast<Var>(prev_call_node->args[0]); |
| return; |
| } else { |
| return reemit_and_return(); |
| } |
| } else if (prev_call_node->op == Op::Get("relax.matmul")) { |
| // rewrite rule #2: permute_dims(matmul(a, b)) -> matmul(permute_dims(b), permute_dims(a)) |
| // Should "a" or "b" already be in the form of "permute_dims", the redundant permute_dims |
| // operation should be eliminated |
| |
| // Skip matmuls with 1-dim input because in these cases we cannot simply transpose the input |
| auto a_dim = MatchStructInfo<TensorStructInfo>(prev_call_node->args[0]).value()->ndim; |
| auto b_dim = MatchStructInfo<TensorStructInfo>(prev_call_node->args[1]).value()->ndim; |
| if (a_dim == 1 || b_dim == 1) { |
| return reemit_and_return(); |
| } |
| |
| auto a = GetTransposeAccordingToCtx(prev_call_node->args[0]); |
| auto b = GetTransposeAccordingToCtx(prev_call_node->args[1]); |
| result = |
| ExprMutator::VisitExpr(matmul(b, a, prev_call_node->attrs.as<MatmulAttrs>()->out_dtype)); |
| ReEmitBinding(binding, result); |
| return; |
| } else { |
| return reemit_and_return(); |
| } |
| } |
| }; |
| |
| Function SimplifyGradient(const Function& func) { return GradientSimplifier::Transform(func); } |
| |
| } // namespace relax |
| } // namespace tvm |