/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file combine_parallel_batch_matmul.cc
 * \brief Combine parallel batch matmuls into a single one.
 *
 * This pass replaces batch_matmul ops that share the same lhs node with a
 * single batch_matmul. Elemwise and broadcast ops following the batch_matmul
 * are also combined if possible.
 *
 * This prevents launching multiple kernels in networks with multiple
 * parallel branches, such as an Inception block.
 */
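/*
 * Illustrative sketch (pseudo-Relay; the exact printed form may differ).
 * Given two batch_matmul ops sharing %x (shape (B, M, K)) with
 * transpose_b=true:
 *
 *   %0 = nn.batch_matmul(%x, %w1)   // %w1: (B, N1, K) -> %0: (B, M, N1)
 *   %1 = nn.batch_matmul(%x, %w2)   // %w2: (B, N2, K) -> %1: (B, M, N2)
 *
 * the pass rewrites them into one combined op plus per-branch slices:
 *
 *   %w = concatenate((%w1, %w2), axis=1)   // (B, N1+N2, K)
 *   %c = nn.batch_matmul(%x, %w)           // (B, M, N1+N2)
 *   %0 = strided_slice(%c, begin=[0, 0, 0],  end=[-1, -1, N1], slice_mode="size")
 *   %1 = strided_slice(%c, begin=[0, 0, N1], end=[-1, -1, N2], slice_mode="size")
 */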
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./combine_parallel_op.h"
#include "./expr_subst.h"
#include "pattern_utils.h"
namespace tvm {
namespace relay {
class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
public:
explicit ParallelBatchMatmulCombiner(uint64_t min_num_branches)
: ParallelOpCombiner("nn.batch_matmul", min_num_branches) {}

 protected:
  // Every "nn.batch_matmul" call is a candidate here; layout and dtype
  // checks happen in CanOpsBeCombined.
  bool IsSupportedOp(const CallNode* n) { return true; }
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<BatchMatmulAttrs>();
const auto* attrs_b = b->attrs.as<BatchMatmulAttrs>();
ICHECK(attrs_a);
ICHECK(attrs_b);
const auto* rhs_a = a->args[1]->type_as<TensorTypeNode>();
const auto* rhs_b = b->args[1]->type_as<TensorTypeNode>();
const auto* restype_a = a->type_as<TensorTypeNode>();
const auto* restype_b = b->type_as<TensorTypeNode>();
    // shape[2] is the contraction axis; it is automatically consistent
    // between two valid batch_matmul ops.
    // TODO(jcf94): Add full support of layout format
    if (!(attrs_a->transpose_a == false && attrs_a->transpose_b == true &&
          attrs_b->transpose_a == false && attrs_b->transpose_b == true)) {
      LOG(WARNING) << "For legacy reasons, this pass only supports"
                   << " (transpose_a=false, transpose_b=true); skip combining"
                   << " batch_matmul_a: (" << attrs_a->transpose_a << ", " << attrs_a->transpose_b
                   << ") with batch_matmul_b: (" << attrs_b->transpose_a << ", "
                   << attrs_b->transpose_b << ")";
      return false;
    }
    // Combinable only when the dtypes, out_dtype, and batch dimension match,
    // and both rhs operands are rank-3 tensors.
    auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) &&
               (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) &&
               eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype);
    return res;
}
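
  // Illustrative note (assumes transpose_b=true): for rhs_a of shape
  // (B, N1, K) and rhs_b of shape (B, N2, K), only B, K, and the dtypes must
  // agree; N1 and N2 may differ, since MakeCombinedOp concatenates the rhs
  // operands along the feature axis.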
  Call MakeCombinedOp(const Group& branches) {
    // All branches share the same lhs (args[0]); collect each branch's rhs
    // so they can be concatenated into a single weight tensor.
    Expr data = branches[0][0]->args[0];
Array<Expr> weights;
for (const auto& branch : branches) {
auto call = branch[0];
weights.push_back(call->args[1]);
}
    // With transpose_b=true the rhs layout is (batch, features, K), so the
    // concatenation runs along axis 1, the feature axis.
    Expr new_weight = MakeConcatenate(Tuple(weights), 1);
const auto* origin_attrs = branches[0][0]->attrs.as<BatchMatmulAttrs>();
ICHECK(origin_attrs);
return Downcast<Call>(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype,
origin_attrs->transpose_a, origin_attrs->transpose_b));
}
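
  // Illustrative result: branch outputs of shapes (B, M, N1) and (B, M, N2)
  // become a single (B, M, N1 + N2) output, which UpdateGroupOutput later
  // slices back apart per branch.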
  // Arguments of following elemwise/broadcast ops are always treated as
  // compatible; they are concatenated along the last axis when combined.
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; }
Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        // The argument coming from the combined batch_matmul is replaced by
        // the already-combined data.
        new_args.push_back(data);
        continue;
      }
      // Every other argument is gathered across branches and concatenated
      // along the last axis to match the combined data layout.
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        tuple.push_back(branch[depth]->args[i]);
      }
      auto concat = MakeConcatenate(Tuple(tuple), -1);
new_args.push_back(std::move(concat));
}
return Call(call->op, new_args, call->attrs, {});
}
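
  /*
   * Sketch (illustrative, hypothetical follower op): if each branch is
   * followed by
   *   %y_i = add(%mm_i, %bias_i)   // %bias_i: (N_i,)
   * then the per-branch adds are replaced by a single
   *   add(%combined, concatenate((%bias_0, %bias_1), axis=-1))
   */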
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) {
int64_t index = 0;
for (const auto& branch : branches) {
      const CallNode* batch_matmul = branch[0];
      // With transpose_b=true the rhs layout is (batch, features, K), so
      // shape[1] is this branch's feature count.
      auto feature_dim = batch_matmul->args[1]->type_as<TensorTypeNode>()->shape[1];
      auto fpp = tir::as_const_int(feature_dim);
      ICHECK(fpp) << "Feature dimension of a combined batch_matmul must be static";
      int64_t features = *fpp;
      // Slice this branch's features back out of the combined output. Under
      // slice_mode="size", end[i] == -1 keeps the whole axis and any other
      // end[i] is the slice length, so this takes `features` elements of the
      // last axis starting at `index`.
      Array<Integer> begin;
      Array<Integer> end;
      for (size_t i = 0; i < 2; i++) {
        begin.push_back(0);
        end.push_back(-1);
      }
      begin.push_back(index);
      index += features;
      end.push_back(features);
      Array<Integer> strides(begin.size(), 1);
      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}
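
  /*
   * Worked example (illustrative): three branches with feature counts 4, 8,
   * and 2 are recovered from the combined output with
   *   begin = [0, 0, 0],  end = [-1, -1, 4]
   *   begin = [0, 0, 4],  end = [-1, -1, 8]
   *   begin = [0, 0, 12], end = [-1, -1, 2]
   * all with strides of 1 and slice_mode="size".
   */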
};
/*! \brief Combine parallel batch_matmul ops if the number of branches >= min_num_branches. */
Expr CombineParallelBatchMatmul(const Expr& expr, uint64_t min_num_branches) {
return ParallelBatchMatmulCombiner(min_num_branches).Combine(expr);
}
namespace transform {
Pass CombineParallelBatchMatmul(uint64_t min_num_branches) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelBatchMatmul(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelBatchMatmul", {"InferType"});
}
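
// Usage sketch (illustrative; `mod` is assumed to be a type-checked IRModule):
//   Pass pass = CombineParallelBatchMatmul(/*min_num_branches=*/3);
//   mod = pass(mod);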
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelBatchMatmul")
.set_body_typed(CombineParallelBatchMatmul);
} // namespace transform
} // namespace relay
} // namespace tvm