/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file combine_parallel_dense.cc
* \brief Combine parallel dense ops into a single dense.
*
 * This pass replaces dense ops that share the same input node, have the same
 * weight shape, and do not have "units" defined with a single batch matrix
 * multiplication. The inputs of the new batch_matmul are the stacks of the
 * original inputs. Elemwise and broadcast ops following dense are also
 * combined where possible.
*
* This prevents launching multiple kernels in networks with multiple
* dense branches, such as BERT.
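 *
 * As a rough sketch in Relay-like pseudo-IR (the %-names and shapes below are
 * illustrative, not taken from this file), two parallel branches
 *
 *   %0 = nn.dense(%x, %w1)        // %x: (m, k), %w1: (n, k)
 *   %1 = nn.dense(%x, %w2)        // %w2: (n, k)
 *
 * are rewritten, when to_batch is true, into approximately
 *
 *   %s = stack((%x, %x))          // (2, m, k)
 *   %w = stack((%w1, %w2))        // (2, n, k)
 *   %y = nn.batch_matmul(%s, %w)  // (2, m, n)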
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./combine_parallel_op_batch.h"
#include "./expr_subst.h"
#include "pattern_util.h"
namespace tvm {
namespace relay {

/*
 * Class that finds and combines parallel dense ops into a single batch_matmul op.
 */
class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
public:
explicit ParallelDenseToBatchCombiner(uint64_t min_num_branches)
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}
protected:
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
CHECK(attrs_a);
CHECK(attrs_b);
const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
    // batch_matmul requires the same out_dtype and identical weight shapes
    // across the branches being combined.
    return eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
           eq(weight_a->shape[0], weight_b->shape[0]) && eq(weight_a->shape[1], weight_b->shape[1]);
}
};

/*
 * Class that finds and combines parallel dense ops into a single dense op
 * whose number of output units equals the sum of the output units of the
 * original dense ops.
 */
class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
public:
explicit ParallelDenseToDenseCombiner(uint64_t min_num_branches)
: ParallelOpCombiner("nn.dense", min_num_branches) {}
protected:
  // Any nn.dense call is a candidate; CanOpsBeCombined does the real filtering.
  bool IsSupportedOp(const CallNode* n) { return true; }
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
CHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
// output dims (weight->shape[0]) can be different
return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
}
Call MakeCombinedOp(const Group& branches) {
const Op& dense_op = Op::Get("nn.dense");
Expr input = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_output_dims;
// concat all weights into one
std::tie(new_weight, new_output_dims) = TransformWeight(branches);
const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
CHECK(origin_attrs);
const auto dense_attrs = make_object<DenseAttrs>();
dense_attrs->units = new_output_dims;
dense_attrs->out_dtype = origin_attrs->out_dtype;
return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
}
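  // For illustration (a sketch; the %-names and shapes are made up): two branches
  //   nn.dense(%x, %w1)   with %w1: (n1, k)
  //   nn.dense(%x, %w2)   with %w2: (n2, k)
  // are fused into
  //   nn.dense(%x, concatenate((%w1, %w2), axis=0))   // units = n1 + n2
  // and UpdateGroupOutput below slices the (m, n1 + n2) result back apart.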
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();
CHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);
if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
return false;
}
if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
return false; // not broadcast/elemwise
}
if (ta->shape.size() > 0) {
for (size_t i = 0; i < ta->shape.size() - 1; i++) {
// shape dims must match except last dim
if (!eq(ta->shape[i], tb->shape[i])) return false;
}
}
return true;
}
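  // Example of the check above (shapes are illustrative): following-op args of
  // shape (m, n1) and (m, n2) are compatible, since only the last dim differs;
  // (m, n1) and (p, n2) with m != p are not.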
Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
new_args.push_back(data);
continue;
}
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
Array<Expr> tuple;
for (const auto& branch : branches) {
auto parent = branch[depth]->args[parent_index];
auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
CHECK(out_dim != nullptr);
auto arg = branch[depth]->args[i];
auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
bool repeat_last_dim = false;
if (arg_ndim == 0) {
repeat_last_dim = true;
arg = MakeExpandDims(arg, -1, 1);
} else {
auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
CHECK(arg_last_dim != nullptr);
if (*out_dim > 1 && *arg_last_dim == 1) {
repeat_last_dim = true;
}
}
if (repeat_last_dim) {
// ensure broadcast is valid after concat args
arg = MakeRepeat(arg, *out_dim, concat_axis);
}
tuple.push_back(arg);
}
auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
new_args.push_back(std::move(concat));
}
return Call(call->op, new_args, call->attrs, {});
}
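  // Sketch of the fix-up above (shapes are illustrative): bias args of shape
  // (n1,) and (n2,) are simply concatenated into (n1 + n2,). A (1,) arg that
  // broadcast against a dense output of width n > 1 is first repeated to (n,),
  // and a scalar arg is expand_dims'ed to (1,) and then repeated, so the
  // concatenated arg still broadcasts correctly against the combined
  // (m, n1 + n2) output.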
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) {
int index = 0;
for (const auto& branch : branches) {
const CallNode* call = branch[depth];
auto& out_shape = call->type_as<TensorTypeNode>()->shape;
auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
CHECK(out_dims != nullptr);
Array<Integer> begin;
Array<Integer> end;
Array<Integer> strides;
for (size_t k = 0; k < out_shape.size() - 1; ++k) {
begin.push_back(0);
end.push_back(-1);
strides.push_back(1);
}
begin.push_back(index);
end.push_back(*out_dims);
strides.push_back(1);
index += *out_dims;
auto slice = MakeStridedSlice(data, begin, end, strides, "size");
subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}
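  // For example (a sketch): with branch widths n1 and n2, the combined output
  // of shape (m, n1 + n2) is sliced back into (m, n1) at offset 0 and (m, n2)
  // at offset n1 (slice_mode="size"), and each original dense output is
  // substituted with its slice.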
private:
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t out_dims = 0;
Array<Expr> weights;
for (const auto& branch : branches) {
auto weight = branch[0]->args[1];
weights.push_back(weight);
out_dims += *tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
}
return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
tir::make_const(DataType::Int(32), out_dims));
}
};

/*! \brief Combine parallel dense ops if the number of branches >= min_num_branches. */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
if (to_batch) {
return ParallelDenseToBatchCombiner(min_num_branches).Combine(expr);
} else {
return ParallelDenseToDenseCombiner(min_num_branches).Combine(expr);
}
}
namespace transform {
Pass CombineParallelDense(uint64_t min_num_branches, bool to_batch_matmul) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches, to_batch_matmul));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
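
// A minimal usage sketch (assumes an existing IRModule `mod`; illustrative
// only, not taken from this file):
//
//   Pass pass = CombineParallelDense(/*min_num_branches=*/3,
//                                    /*to_batch_matmul=*/true);
//   mod = pass(mod);
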
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense").set_body_typed(CombineParallelDense);
} // namespace transform
} // namespace relay
} // namespace tvm