blob: 00b9a3be3b2ee3323f42f75ba46607134add3032 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
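/*!
 * \file split_args.cc
 * \brief Split `concatenate` calls whose inputs plus outputs exceed a given
 *        maximum number of function arguments into smaller, separately fused
 *        concatenates.
 */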
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <algorithm>

#include "../op/annotation/annotation.h"
#include "./pattern_utils.h"
namespace tvm {
namespace relay {
/*!
 * \brief Rewriter that splits `concatenate` calls with too many inputs.
 *
 * When the inputs of a `concatenate` plus its outputs would exceed
 * `max_function_args`, the inputs are grouped into chunks of at most
 * `max_function_args - outputs` elements. Each chunk is concatenated on its own
 * and wrapped in StopFusion, and the chunk results are concatenated again along
 * the same axis.
 */
class ArgumentSplitter : public ExprRewriter {
 public:
  /*!
   * \param max_function_args Maximum number of arguments (inputs plus outputs)
   *        allowed per function. A negative value disables splitting.
   */
  explicit ArgumentSplitter(int max_function_args)
      : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    if (max_function_args_ < 0) return post;
    if (call->op != concat_op_) return post;

    // Operate on the already-rewritten call (`post`), not the original `call`.
    // Otherwise rewrites applied to the tuple's fields (e.g. nested concatenates
    // that were split earlier in this post-order traversal) would be discarded.
    const auto* post_call = post.as<CallNode>();
    if (post_call == nullptr) return post;
    const auto* tuple_node = post_call->args[0].as<TupleNode>();
    // The input is not a literal tuple (e.g. a tuple-typed variable): its
    // fields cannot be regrouped here, so leave the call untouched.
    if (tuple_node == nullptr) return post;
    const auto* param = post_call->attrs.as<ConcatenateAttrs>();
    if (param == nullptr) return post;

    // The type is read from the original call because a freshly rebuilt `post`
    // node may not carry a checked type yet.
    int outputsNum = 1;
    if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
      outputsNum = static_cast<int>(tuple_type->fields.size());
    }
    const int limit = max_function_args_ - outputsNum;
    // With fewer than two inputs per chunk, chunking cannot reduce the number of
    // inputs of the outer concatenate. limit <= 0 would also divide by zero or
    // produce a negative chunk count.
    if (limit < 2) return post;
    const int argsNum = static_cast<int>(tuple_node->fields.size());
    // Inputs plus outputs already fit within max_function_args_.
    if (argsNum <= limit) return post;

    const int splitNum = (argsNum + limit - 1) / limit;  // ceil(argsNum / limit)
    tvm::Array<Expr> splitted;
    splitted.reserve(splitNum);
    for (int startIdx = 0; startIdx < argsNum; startIdx += limit) {
      const int argsCount = std::min(limit, argsNum - startIdx);
      tvm::Array<Expr> args;
      args.reserve(argsCount);
      for (int j = 0; j < argsCount; ++j) {
        args.push_back(tuple_node->fields[startIdx + j]);
      }
      Tuple chunk = WithFields(GetRef<Tuple>(tuple_node), args);
      // StopFusion keeps each chunk in its own fused function.
      splitted.push_back(StopFusion(MakeConcatenate(chunk, param->axis)));
    }
    // NOTE(review): the outer concatenate takes splitNum inputs, which can
    // itself exceed `limit` when splitNum > limit; it is not split further here.
    Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), splitted);
    return MakeConcatenate(new_tuple, param->axis);
  }

 private:
  /*! \brief Maximum arguments per function; negative disables the rewrite. */
  const int max_function_args_;
  /*! \brief The `concatenate` operator, looked up once at construction. */
  const Op& concat_op_;
};
/*!
 * \brief Apply ArgumentSplitter to \p expr in post order.
 * \param expr The expression to rewrite.
 * \param max_function_args Maximum arguments per function; negative disables splitting.
 * \return The rewritten expression.
 */
Expr SplitArgs(const Expr& expr, int max_function_args) {
  ArgumentSplitter splitter(max_function_args);
  Expr rewritten = PostOrderRewrite(expr, &splitter);
  return rewritten;
}
namespace transform {

/*!
 * \brief Create a function-level pass that runs relay::SplitArgs on each function.
 *
 * If the module carries attributes, they are attached to the rewritten function.
 * The pass is registered at opt level 1 and requires InferType.
 *
 * \param max_function_args Maximum arguments per function; negative disables splitting.
 * \return The SplitArgs pass.
 */
Pass SplitArgs(int max_function_args) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) -> Function {
        Function rewritten = Downcast<Function>(relay::SplitArgs(f, max_function_args));
        if (!m->attrs.defined()) {
          return rewritten;
        }
        return WithAttrs(rewritten, {m->attrs->dict});
      };
  return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs);

}  // namespace transform
} // namespace relay
} // namespace tvm