| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file simplify_inference.cc |
| */ |
| #include <tvm/relay/analysis.h> |
| #include <tvm/relay/attrs/nn.h> |
| #include <tvm/relay/expr_functor.h> |
| #include <tvm/relay/op.h> |
| #include <tvm/relay/transform.h> |
| |
| #include "pattern_utils.h" |
| |
| namespace tvm { |
| namespace relay { |
| |
| Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean, |
| Expr moving_var, Type tdata) { |
| auto ttype = tdata.as<TensorTypeNode>(); |
| ICHECK(ttype); |
| const auto param = attrs.as<BatchNormAttrs>(); |
| Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
| Expr var_add_eps = Add(moving_var, epsilon); |
| Expr sqrt_var = Sqrt(var_add_eps); |
| Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var); |
| |
| if (param->scale) { |
| scale = Multiply(scale, gamma); |
| } |
| Expr neg_mean = Negative(moving_mean); |
| Expr shift = Multiply(neg_mean, scale); |
| if (param->center) { |
| shift = Add(shift, beta); |
| } |
| |
| auto ndim = ttype->shape.size(); |
| int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
| scale = ExpandBiasToMatchAxis(scale, ndim, {axis}); |
| shift = ExpandBiasToMatchAxis(shift, ndim, {axis}); |
| |
| Expr out = Multiply(data, scale); |
| out = Add(out, shift); |
| return out; |
| } |
| |
| Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { |
| auto ttype = tdata.as<TensorTypeNode>(); |
| ICHECK(ttype); |
| const auto param = attrs.as<GroupNormAttrs>(); |
| ICHECK(param); |
| |
| int ndim = ttype->shape.size(); |
| int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
| Array<Integer> reduced_axes; |
| Array<Integer> new_shape; |
| Array<Integer> old_shape; |
| |
| int num_groups = param->num_groups; |
| int channel = ttype->shape[axis].as<IntImmNode>()->value; |
| |
| // old_shape = N, C, H, W |
| // new shape = N, num_groups, C/num_groups, H, W |
| // reduce_axes = axis of (C/num_groups, H, W) |
| for (int i = 0; i < ndim; ++i) { |
| auto val = ttype->shape[i].as<IntImmNode>()->value; |
| |
| // Save the old shape to reshape later |
| old_shape.push_back(val); |
| if (i == axis) { |
| new_shape.push_back(num_groups); |
| new_shape.push_back(channel / num_groups); |
| reduced_axes.push_back(i + 1); |
| continue; |
| } |
| if (i >= axis) { |
| reduced_axes.push_back(i + 1); |
| } |
| new_shape.push_back(val); |
| } |
| |
| data = Reshape(data, new_shape); |
| |
| Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
| Expr mean = Mean(data, {reduced_axes}, true, false); |
| Expr var = Variance(data, mean, {reduced_axes}, true, false); |
| Expr denom = Sqrt(Add(var, epsilon)); |
| Expr out = Divide(Subtract(data, mean), denom); |
| |
| out = Reshape(out, old_shape); |
| |
| if (param->scale) { |
| out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); |
| } |
| if (param->center) { |
| out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); |
| } |
| |
| return out; |
| } |
| |
| Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { |
| auto ttype = tdata.as<TensorTypeNode>(); |
| ICHECK(ttype); |
| const auto param = attrs.as<LayerNormAttrs>(); |
| ICHECK(param); |
| |
| Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
| Expr mean = Mean(data, {param->axis}, true, false); |
| Expr var = Variance(data, mean, {param->axis}, true, false); |
| Expr denom = Sqrt(Add(var, epsilon)); |
| Expr out = Divide(Subtract(data, mean), denom); |
| |
| size_t ndim = ttype->shape.size(); |
| int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
| if (param->scale) { |
| out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); |
| } |
| if (param->center) { |
| out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); |
| } |
| return out; |
| } |
| |
| Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { |
| auto ttype = tdata.as<TensorTypeNode>(); |
| ICHECK(ttype); |
| const auto param = attrs.as<InstanceNormAttrs>(); |
| ICHECK(param); |
| |
| int ndim = ttype->shape.size(); |
| int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
| Array<Integer> reduced_axes; |
| for (int i = 1; i < ndim; ++i) { |
| if (i != axis) reduced_axes.push_back(i); |
| } |
| |
| Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
| Expr mean = Mean(data, reduced_axes, true, false); |
| Expr var = Variance(data, mean, reduced_axes, true, false); |
| Expr denom = Sqrt(Add(var, epsilon)); |
| Expr out = Divide(Subtract(data, mean), denom); |
| |
| if (param->scale) { |
| out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); |
| } |
| if (param->center) { |
| out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); |
| } |
| return out; |
| } |
| |
| Expr L2NormToInferUnpack(const Attrs attrs, Expr data) { |
| const auto param = attrs.as<L2NormalizeAttrs>(); |
| ICHECK(param); |
| |
| Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->eps)); |
| |
| Expr sqr = Multiply(data, data); |
| Expr sum = Maximum(Sum(sqr, param->axis, true, false), epsilon); |
| Expr sqrt = Sqrt(sum); |
| return Divide(data, sqrt); |
| } |
| |
| class InferenceSimplifier : public MixedModeMutator { |
| public: |
| InferenceSimplifier() |
| : batch_norm_op_(Op::Get("nn.batch_norm")), |
| dropout_op_(Op::Get("nn.dropout")), |
| instance_norm_op_(Op::Get("nn.instance_norm")), |
| layer_norm_op_(Op::Get("nn.layer_norm")), |
| group_norm_op_(Op::Get("nn.group_norm")), |
| l2_norm_op_(Op::Get("nn.l2_normalize")) {} |
| |
| Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final { |
| const auto* new_n = new_e.as<TupleGetItemNode>(); |
| if (new_n->index != 0) { |
| return new_e; |
| } |
| if (const auto* call = new_n->tuple.as<CallNode>()) { |
| if (call->op == batch_norm_op_) { |
| return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], |
| call->args[3], call->args[4], ty_map_.at(call->args[0])); |
| } else if (call->op == dropout_op_) { |
| return call->args[0]; |
| } |
| } |
| return new_e; |
| } |
| |
| Expr Rewrite_(const CallNode* n, const Expr& new_n) { |
| if (n->op == batch_norm_op_) { |
| ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type(); |
| } else if (n->op == layer_norm_op_) { |
| const auto* call = new_n.as<CallNode>(); |
| return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], |
| n->args[0]->checked_type()); |
| } else if (n->op == group_norm_op_) { |
| const auto* call = new_n.as<CallNode>(); |
| return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], |
| n->args[0]->checked_type()); |
| } else if (n->op == instance_norm_op_) { |
| const auto* call = new_n.as<CallNode>(); |
| return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], |
| n->args[0]->checked_type()); |
| } else if (n->op == l2_norm_op_) { |
| const auto* call = new_n.as<CallNode>(); |
| return L2NormToInferUnpack(call->attrs, call->args[0]); |
| } |
| return new_n; |
| } |
| |
| private: |
| // Cache the following ops. They will be used in the passes repeatedly for |
| // operator equivalence checking so that the registry lookup overhead can be |
| // reduced. |
| const Op& batch_norm_op_; |
| const Op& dropout_op_; |
| const Op& instance_norm_op_; |
| const Op& layer_norm_op_; |
| const Op& group_norm_op_; |
| const Op& l2_norm_op_; |
| std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> ty_map_; |
| }; |
| |
| Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); } |
| |
| namespace transform { |
| |
| Pass SimplifyInference() { |
| runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
| [=](Function f, IRModule m, PassContext pc) { |
| return Downcast<Function>(SimplifyInference(f)); |
| }; |
| return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); |
| } |
| |
| TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference); |
| |
| } // namespace transform |
| |
| } // namespace relay |
| } // namespace tvm |