/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/gradient_simplifier.cc
* \brief Simplify patterns generated by the gradient pass. Only used in gradient.cc.
* \sa tvm/relax/transform/gradient.cc
*
* We will simplify these patterns:
* (transpose means use permute_dims to transpose the last two dimensions)
* 1. Forward is: out = matmul(a, transpose(b))
* Then backward is:
* grad_a = matmul(grad_out, transpose(transpose(b)))
* grad_b = transpose(matmul(transpose(a), grad_out))
* We will simplify it to:
* grad_a = matmul(grad_out, b)
* grad_b = matmul(transpose(grad_out), a)
* 2. Forward is: out = matmul(transpose(a), b)
* Then backward is:
* grad_a = transpose(matmul(grad_out, transpose(b)))
* grad_b = matmul(transpose(transpose(a)), grad_out)
* We will simplify it to:
* grad_a = matmul(b, transpose(grad_out))
* grad_b = matmul(a, grad_out)
* 3. Forward is: out = matmul(transpose(a), transpose(b))
* Then backward is:
* grad_a = transpose(matmul(grad_out, transpose(transpose(b))))
* grad_b = transpose(matmul(transpose(transpose(a)), grad_out))
* We will simplify it to:
* grad_a = matmul(transpose(b), transpose(grad_out))
* grad_b = matmul(transpose(grad_out), transpose(a))
*/
#include "gradient_simplifier.h"
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/manipulate.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"
namespace tvm {
namespace relax {
/*!
 * \brief Simplify patterns generated by the gradient pass. Especially, simplify the matmul
 * patterns.
 */
class GradientSimplifier : private ExprMutator {
 public:
  /*!
   * \brief Simplify the given function by rewriting the transpose/matmul patterns listed in the
   * file header, and then remove any bindings that the rewrite leaves unused.
   *
   * \param func The original function.
   * \return The simplified function.
   */
  static Function Transform(const Function& func) {
    return Downcast<Function>(RemoveAllUnused(GradientSimplifier().VisitExpr(func)));
  }

 private:
  // Returns true iff call_node is a relax.permute_dims call that swaps exactly the last two
  // dimensions of a tensor whose ndim is known and >= 2 (the "transpose" of the file header).
  static bool IsTransposeOp(const CallNode* call_node) {
    if (call_node->op != Op::Get("relax.permute_dims")) {
      return false;
    }
    auto sinfo = MatchStructInfo<TensorStructInfo>(call_node->args[0]);
    if (!sinfo) {
      return false;
    }
    auto ndim = sinfo.value()->ndim;
    // Need at least two known dimensions to swap.
    if (ndim == kUnknownNDim || ndim == 1) {
      return false;
    }
    // With no explicit axes, permute_dims reverses all axes, which matches a last-two-dims
    // transpose only for 2-D tensors.
    if (!call_node->attrs.as<PermuteDimsAttrs>()->axes.defined()) {
      return ndim == 2;
    }
    auto axes = call_node->attrs.as<PermuteDimsAttrs>()->axes.value();
    TVM_FFI_ICHECK(static_cast<int>(axes.size()) == ndim);
    // The leading axes must be the identity permutation ...
    for (int i = 0; i < ndim - 2; ++i) {
      if (axes[i] != i) {
        return false;
      }
    }
    // ... and the last two axes must be swapped.
    return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2;
  }

  // Return permute_dims(expr). Generate the axes needed.
  // A 1-dim tensor is returned unchanged since transposing it is the identity.
  static Expr GetTransposeOf(const Expr& expr) {
    auto sinfo = MatchStructInfo<TensorStructInfo>(expr);
    TVM_FFI_ICHECK(sinfo);
    auto ndim = sinfo.value()->ndim;
    if (ndim == 1) {
      return expr;
    }
    // Build axes [0, ..., ndim-3, ndim-1, ndim-2]: identity except the last two swapped.
    auto axes = ffi::Array<Integer>();
    for (int i = 0; i < ndim - 2; ++i) {
      axes.push_back(i);
    }
    axes.push_back(ndim - 1);
    axes.push_back(ndim - 2);
    return permute_dims(expr, axes);
  }

  // If expr is already in the form of permute_dims in previous bindings, return the input of the
  // permute_dims op
  // Else, return permute_dims(expr)
  Expr GetTransposeAccordingToCtx(const Expr& expr) {
    // Only vars can be looked up in the builder's binding table.
    if (!expr->IsInstance<VarNode>()) {
      return GetTransposeOf(expr);
    }
    auto prev_expr = builder_->LookupBinding(Downcast<Var>(expr));
    if (!prev_expr || !prev_expr->IsInstance<CallNode>()) {
      return GetTransposeOf(expr);
    }
    auto prev_call_node = prev_expr.as<CallNode>();
    if (!IsTransposeOp(prev_call_node)) {
      return GetTransposeOf(expr);
    }
    // expr is bound to permute_dims(x), so its transpose is simply x.
    return prev_call_node->args[0];
  }

  // Rewrite bindings of the form `var = call(...)` using two rules:
  //   #1: permute_dims(permute_dims(a)) -> a  (recorded through var_remap_, no new binding)
  //   #2: permute_dims(matmul(a, b)) -> matmul(transpose(b), transpose(a)), reusing existing
  //       permute_dims bindings for a / b where possible.
  // Bindings matching neither rule are re-emitted after recursive mutation.
  void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) {
    auto result = ExprMutator::VisitExpr(ffi::GetRef<Expr>(call_node));
    auto new_call_node = result.as<CallNode>();
    auto reemit_and_return = [&]() {
      ReEmitBinding(binding, result);
      return;
    };
    if (!IsTransposeOp(new_call_node)) {
      return reemit_and_return();
    }
    auto arg = new_call_node->args[0];
    if (!arg->IsInstance<VarNode>()) {
      return reemit_and_return();
    }
    auto prev_expr = builder_->LookupBinding(Downcast<Var>(arg));
    if (!prev_expr || !prev_expr->IsInstance<CallNode>()) {
      return reemit_and_return();
    }
    auto prev_call_node = prev_expr.as<CallNode>();
    if (IsTransposeOp(prev_call_node)) {
      // rewrite rule #1: permute_dims(permute_dims(a)) -> a
      if (prev_call_node->args[0]->IsInstance<VarNode>()) {
        var_remap_[binding->var->vid] = Downcast<Var>(prev_call_node->args[0]);
        return;
      } else {
        return reemit_and_return();
      }
    } else if (prev_call_node->op == Op::Get("relax.matmul")) {
      // rewrite rule #2: permute_dims(matmul(a, b)) -> matmul(permute_dims(b), permute_dims(a))
      // Should "a" or "b" already be in the form of "permute_dims", the redundant permute_dims
      // operation should be eliminated
      // Skip matmuls with 1-dim input because in these cases we cannot simply transpose the input
      auto a_dim = MatchStructInfo<TensorStructInfo>(prev_call_node->args[0]).value()->ndim;
      auto b_dim = MatchStructInfo<TensorStructInfo>(prev_call_node->args[1]).value()->ndim;
      if (a_dim == 1 || b_dim == 1) {
        return reemit_and_return();
      }
      auto a = GetTransposeAccordingToCtx(prev_call_node->args[0]);
      auto b = GetTransposeAccordingToCtx(prev_call_node->args[1]);
      result =
          ExprMutator::VisitExpr(matmul(b, a, prev_call_node->attrs.as<MatmulAttrs>()->out_dtype));
      ReEmitBinding(binding, result);
      return;
    } else {
      return reemit_and_return();
    }
  }
};
/*!
 * \brief Entry point of the simplifier: rewrite gradient-generated transpose/matmul patterns.
 * \param func The function produced by the gradient pass.
 * \return The simplified function.
 */
Function SimplifyGradient(const Function& func) {
  // Delegate to the mutator, which also drops bindings made dead by the rewrite.
  auto simplified = GradientSimplifier::Transform(func);
  return simplified;
}
} // namespace relax
} // namespace tvm