blob: a6506cbe98ec63e6a931426440854f9b740ef27d [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*! \file src/relax/transform/decompose_ops.cc */
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/nn.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
#include <tvm/tirx/function.h>
#include <unordered_set>
#include "utils.h"
namespace tvm {
namespace relax {
TensorType MatchTensorType(Expr data) {
auto _ty = MatchType<TensorType>(data);
TVM_FFI_ICHECK(_ty.defined()) << "Expect data to be a tensor, but get " << GetType(data);
return _ty.value();
}
Expr ExpandToMatchInput(Expr data, int ndim, ffi::Array<int64_t> axes) {
axes = GetOrderedPositiveAxes(axes, ndim);
ffi::Array<int64_t> expand_axes;
for (int i = 0, j = 0; i < ndim; ++i) {
if (j < static_cast<int>(axes.size()) && i == axes[j]) {
++j;
} else {
expand_axes.push_back(i);
}
}
return expand_dims(data, expand_axes);
}
Tuple DecomposeBatchNorm(const Call& call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
TVM_FFI_ICHECK_NOTNULL(attrs);
Expr data = call->args[0];
TensorType ty = MatchTensorType(data);
Expr gamma = call->args[1];
Expr beta = call->args[2];
Expr moving_mean = ExpandToMatchInput(call->args[3], ty->ndim, {attrs->axis});
Expr moving_var = ExpandToMatchInput(call->args[4], ty->ndim, {attrs->axis});
// output = (x - mean) / sqrt(var + epsilon) * gamma + beta
Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype.value()->dtype);
Expr sqrt_var = sqrt(add(moving_var, epsilon));
Expr out = divide(subtract(data, moving_mean), sqrt_var);
if (attrs->scale) {
out = multiply(out, ExpandToMatchInput(gamma, ty->ndim, {attrs->axis}));
}
if (attrs->center) {
out = add(out, ExpandToMatchInput(beta, ty->ndim, {attrs->axis}));
}
return Tuple({out, call->args[3], call->args[4]});
}
Expr MutateBatchNormForTraining(Call call) {
auto attrs = call->attrs.as<BatchNormAttrs>();
TVM_FFI_ICHECK_NOTNULL(attrs);
TVM_FFI_ICHECK_EQ(call->args.size(), 5);
Expr data = call->args[0];
Expr gamma = call->args[1];
Expr beta = call->args[2];
Expr moving_mean = call->args[3];
Expr moving_var = call->args[4];
TensorType ty = MatchTensorType(data);
ffi::Array<int64_t> reduce_axes;
for (int i = 0; i < ty->ndim; ++i) {
if (i != attrs->axis) {
reduce_axes.push_back(i);
}
}
Expr data_mean = mean(data, reduce_axes, false);
Expr data_var = variance(data, reduce_axes, false);
Expr momentum = MakeConstantScalar(attrs->momentum, ty->dtype.value()->dtype);
Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, ty->dtype.value()->dtype);
Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean));
Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var));
call.CopyOnWrite()->args = {data, gamma, beta, data_mean, data_var};
// return call;
return relax::Tuple({TupleGetItem(call, 0), new_moving_mean, new_moving_var});
}
Expr DecomposeLayerNorm(const Call& call) {
auto attrs = call->attrs.as<LayerNormAttrs>();
TVM_FFI_ICHECK_NOTNULL(attrs);
Expr data = call->args[0];
TensorType ty = MatchTensorType(data);
Expr gamma = call->args[1];
Expr beta = call->args[2];
Expr data_mean = mean(data, attrs->axes, true);
Expr data_var = variance(data, attrs->axes, true);
// output = (x - mean) / sqrt(var + epsilon) * gamma + beta
Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype.value()->dtype);
Expr sqrt_var = sqrt(add(data_var, epsilon));
Expr out = divide(subtract(data, data_mean), sqrt_var);
if (attrs->scale) {
out = multiply(out, gamma);
}
if (attrs->center) {
out = add(out, beta);
}
return out;
}
Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) {
TVM_FFI_ICHECK(call_node->ty.defined());
Expr expr = call_node->args[0];
const ShapeTypeNode* ty = GetTypeAs<ShapeTypeNode>(call_node);
TVM_FFI_ICHECK(ty);
// call builtin function that converts tensor to shape tuple
// TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape"
static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
Var call =
builder->Emit(Call(call_pure_packed_op, {ExternFunc("vm.builtin.tensor_to_shape"), expr}, {},
{ffi::GetRef<ShapeType>(ty)}));
// Operators like reshape take the output of `TensorToShape` as their output shape.
// Because TOPI expects to have such output shape in symbolic shape at least (i.e.,
// ffi::Array<PrimExpr>), we define symbolic variables and returns them as a ShapeExpr.
ffi::Array<PrimExpr> shape_var;
for (int i = 0; i < ty->ndim; i++) {
shape_var.push_back(tirx::Var("x", PrimType::Int(64)));
}
// bind symbolic variables to the shape tuple
relax::Var var("y", ShapeType(shape_var));
builder->EmitNormalized(MatchCast(var, call, ShapeType(shape_var)));
return ShapeExpr(shape_var);
}
/*! \brief Update operators that have a training-specific form
*
* Some operators, such as relax.op.batch_norm, need additional
* processing when being run for training. This mutator applies any mutations required
*/
class TrainingOperatorMutator : public ExprMutator {
private:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call_node) final {
Call call = VisitExprPostOrder_(call_node).as_or_throw<Call>();
if (call->op == batch_norm_op_) {
return MutateBatchNormForTraining(call);
} else if (call->op == layer_norm_op_) {
// Here we only decompose LayerNorm in training because it is more efficient as a single op.
// In the future maybe we can also remove this decomposition during training.
return DecomposeLayerNorm(call);
} else {
return call;
}
}
/* composite opeartor list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
const Op& layer_norm_op_ = Op::Get("relax.nn.layer_norm");
};
class OpDecomposer : public ExprMutator {
private:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call_node) final {
Call call = VisitExprPostOrder_(call_node).as_or_throw<Call>();
if (call->op == batch_norm_op_) {
return DecomposeBatchNorm(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call, builder_);
}
return call;
}
/* composite opeartor list */
const Op& batch_norm_op_ = Op::Get("relax.nn.batch_norm");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
};
namespace transform {
namespace {
/*! \brief Helper: add or remove an attribute on a BaseFunc */
BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any attr_value) {
if (auto tirx = func.as<tirx::PrimFunc>()) {
return WithAttr(tirx.value(), attr_key, attr_value);
} else if (auto relax_fn = func.as<relax::Function>()) {
return WithAttr(relax_fn.value(), attr_key, attr_value);
} else {
return func;
}
}
BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {
if (auto tirx = func.as<tirx::PrimFunc>()) {
return WithoutAttr(tirx.value(), attr_key);
} else if (auto relax_fn = func.as<relax::Function>()) {
return WithoutAttr(relax_fn.value(), attr_key);
} else {
return func;
}
}
/*!
* \brief Apply a pass to a single named function within an IRModule.
*
* Replaces all other functions with dummy ExternFunc stubs so that the
* pass does not see them, then restores the original module. Uses
* exact name match (not a regex) because all in-tree callers supply a
* literal function name.
*/
Pass ApplyDecomposeToFunction(Pass pass, ffi::String func_name) {
auto pass_func = [pass, func_name](IRModule mod, PassContext) -> IRModule {
std::unordered_set<ffi::String> keep_original_version;
std::unordered_set<ffi::String> internal_functions;
IRModule subset;
for (auto [gvar, func] : mod->functions) {
if (gvar->name_hint == func_name) {
if (!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).has_value()) {
// Mark internal functions as externally-exposed so that
// call-tracing transforms inside the pass do not remove them.
internal_functions.insert(gvar->name_hint);
func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
}
} else {
// Replace non-target functions with stubs to keep references intact.
keep_original_version.insert(gvar->name_hint);
func = relax::ExternFunc("dummy_" + std::string(gvar->name_hint));
func->ty = gvar->ty;
}
subset->Add(gvar, func);
}
IRModule new_subset = pass(subset);
if (new_subset.same_as(subset)) {
return mod;
}
auto write_ptr = mod.CopyOnWrite();
for (auto [gvar, func] : new_subset->functions) {
if (!keep_original_version.count(gvar->name_hint)) {
if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
it != write_ptr->global_var_map_.end()) {
write_ptr->Remove((*it).second);
}
if (internal_functions.count(gvar->name_hint)) {
func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);
}
write_ptr->Add(gvar, func);
}
}
return mod;
};
std::string pass_name = "ApplyDecomposeTo" + std::string(func_name);
return CreateModulePass(pass_func, 0, pass_name, {});
}
} // namespace
Pass MutateOpsForTraining() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
TrainingOperatorMutator mutator;
return mutator(func).as_or_throw<Function>();
};
return CreateFunctionPass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"MutateOpsForTraining",
/*required=*/{});
}
Pass DecomposeOps() {
auto pass_func = [](Function func, IRModule, PassContext) -> Function {
OpDecomposer mutator;
return mutator(func).as_or_throw<Function>();
};
return CreateFunctionPass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"DecomposeOps",
/*required=*/{});
}
Pass DecomposeOpsForInference(ffi::Optional<ffi::String> func_name) {
if (func_name) {
return ApplyDecomposeToFunction(DecomposeOps(), func_name.value());
} else {
return DecomposeOps();
}
}
Pass DecomposeOpsForTraining(ffi::Optional<ffi::String> func_name) {
auto module_pass = tvm::transform::Sequential({MutateOpsForTraining(), DecomposeOps()},
"DecomposeOpsForTraining");
if (func_name) {
return ApplyDecomposeToFunction(module_pass, func_name.value());
} else {
return module_pass;
}
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("relax.transform.DecomposeOpsForInference", DecomposeOpsForInference)
.def("relax.transform.DecomposeOpsForTraining", DecomposeOpsForTraining);
}
} // namespace transform
} // namespace relax
} // namespace tvm