blob: 6f53e86d534b061fb936b036aad5babd475a7758 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file combine_parallel_op.h
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/
#ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
#define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "./expr_subst.h"
#include "pattern_util.h"
namespace tvm {
namespace relay {
/*! \brief A branch: a root call followed by the calls chained after it. */
using Branch = std::vector<const CallNode*>;
/*! \brief A group of parallel branches that are candidates for being combined together. */
using Group = std::vector<Branch>;
/*! \brief Predicate: returns true if the call may be the root of a branch to combine. */
using FIsSupportedOp = std::function<bool(const CallNode* n)>;
/*! \brief Predicate: returns true if two root calls may be combined with each other. */
using FAreCompatibleOps = std::function<bool(const CallNode* a, const CallNode* b)>;
/*!
 * \brief Map from an expression to the expression that should replace it
 *        (keys hashed and compared by object pointer).
 */
using ExprSubstMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>;
/*!
 * \brief Finds parallel branches that start with a given op and groups those that can be
 *        combined.
 *
 * Branches are eligible to be combined if their root ops have the same input data.
 * Each root op can be followed by zero or more elementwise or broadcast ops, which are
 * included in the branch. Intermediate nodes have exactly one successor. Branches may
 * still meet again at a later point; that case should be handled in ParallelOpCombiner.
 *
 *          data
 *         /    \
 *        op     op
 *        |       |
 *   elem-wise elem-wise
 *        |       |
 */
class BranchGroupFinder : private ExprVisitor {
 public:
  /*!
   * \brief Constructor.
   * \param op The op that indicates the start of each branch.
   * \param fis_supported_op Function that returns true if a call to op is supported
   *        for combining.
   * \param fare_compatible_ops Function that returns true if two calls are compatible
   *        for combining.
   */
  BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op,
                    FAreCompatibleOps fare_compatible_ops);

  /*!
   * \brief Finds all groups that can be combined.
   * \param expr Relay expression that represents the function to search for groups.
   * \return Vector of groups which can be combined.
   */
  std::vector<Group> Find(const Expr& expr);

 private:
  /*!
   * \brief The op that starts each parallel branch.
   *
   * Held by value rather than as `const Op&`: a reference member would dangle if the
   * constructor were given a temporary Op that does not outlive this finder, and it would
   * also make the class non-assignable. Op is an ObjectRef handle, so the copy is cheap.
   */
  Op cached_op_;
  /*! \brief Returns true if an op is eligible to be combined, false otherwise. */
  FIsSupportedOp fis_supported_op_;
  /*! \brief Returns true if two parallel ops are eligible to be combined, false otherwise. */
  FAreCompatibleOps fare_compatible_ops_;
  /*!
   * \brief Ops that are on the first (logically, leftmost) branch of parallel ops and
   *        are eligible to be combined.
   */
  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> op_roots_;
  /*! \brief Map of Expr to the CallNodes that follow it. */
  std::unordered_map<Expr, std::vector<const CallNode*>, ObjectPtrHash, ObjectPtrEqual>
      children_map_;

  /*!
   * \brief Creates a new branch from op and its children that have elementwise or
   *        broadcast patterns.
   * \param op The root call of the branch.
   * \return New branch.
   */
  Branch CreateBranch(const CallNode* op);

  /*!
   * \brief Visitor hook for call nodes.
   */
  void VisitExpr_(const CallNode* n) final;
};
/*!
 * \brief Abstract class to find and combine parallel ops and the elementwise or broadcast
 *        ops that follow them.
 *
 * Subclasses provide the op-specific behaviour through the pure virtual hooks declared
 * in the protected section.
 */
class ParallelOpCombiner {
 public:
  /*! \brief Virtual destructor, so subclasses can be destroyed through a base pointer. */
  virtual ~ParallelOpCombiner() = default;

  /*!
   * \brief Constructor.
   * \param op_name Name of the op to combine.
   * \param min_num_branches Minimum number of parallel branches beginning with op
   *        required to start combining.
   */
  explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches);

  /*!
   * \brief Combines ops and their following elementwise or broadcast ops.
   * \param expr Function to modify.
   * \return New function with combined ops.
   */
  Expr Combine(const Expr& expr);

 protected:
  /*!
   * \brief Checks if a node is supported to be combined.
   * \param n Node in question.
   * \return True if the op represented by n is supported to be the root of a branch
   *         to be combined, false otherwise.
   */
  virtual bool IsSupportedOp(const CallNode* n) = 0;

  /*!
   * \brief Checks if two ops can be combined.
   * \param a Node a.
   * \param b Node b.
   * \return True if a and b can be combined, false otherwise.
   */
  virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0;

  /*!
   * \brief Makes the combined op from the parallel ops in branches. This usually involves
   *        concatenating or stacking inputs, then creating a new call.
   * \param branches Branches that are to be combined.
   * \return New call with the branches combined.
   */
  virtual Call MakeCombinedOp(const Group& branches) = 0;

  /*!
   * \brief Checks if an argument of the ops following the combined ops can be combined.
   * \param a Node a.
   * \param b Node b.
   * \param index Index of the argument in question.
   * \return True if the index-th arguments of a and b can be combined, false otherwise.
   */
  virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0;

  /*!
   * \brief Creates a combined call from the ops that follow the initial combined op at
   *        the depth-th level. This usually involves concatenating or stacking inputs,
   *        then creating a new call. Only called if IsArgCompatible returns true for
   *        each arg.
   * \param data Combined op.
   * \param branches Branches of parallel ops to be combined.
   * \param depth Depth at which to combine ops.
   * \param parent_index Index of the arg that corresponds to the original input that was
   *        shared among all combined ops.
   * \return New combined call.
   */
  virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches,
                                                size_t depth, size_t parent_index) = 0;

  /*!
   * \brief Updates the map of Exprs to substitute with the combined expr. This usually
   *        involves slicing or splitting data.
   * \param data Combined op.
   * \param branches Branches of parallel ops to be combined.
   * \param depth Depth at which to substitute.
   * \param subst_map Map from each Expr to be replaced to the Expr that replaces it;
   *        updated by this call.
   */
  virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                                 ExprSubstMap* subst_map) = 0;

 private:
  /*!
   * \brief The op to be combined.
   *
   * Held by value rather than as `const Op&` so this object never depends on the lifetime
   * of the Op it was initialized from. Op is an ObjectRef handle, so the copy is cheap.
   */
  Op cached_op_;
  /*! \brief Minimum number of parallel branches required to combine. */
  uint64_t min_num_branches_;
  /*! \brief Map of Expr to the Expr to substitute it with after running the pass. */
  ExprSubstMap subst_map_;

  /*!
   * \brief Combines parallel branches and updates subst_map_ with the Exprs to be
   *        substituted.
   * \param branches Branches to be combined.
   */
  void CombineBranches(const Group& branches);

  /*!
   * \brief Checks whether the ops at the given depth of all parallel branches can be
   *        combined.
   * \param branches Parallel branches to potentially be combined.
   * \param depth Depth at which to look at the ops.
   * \param parent_index Index of the arg that corresponds to the original input that was
   *        shared among all combined ops.
   * \return True if the parallel ops at depth can be combined, false otherwise.
   */
  bool CheckLevel(const Group& branches, size_t depth, size_t parent_index);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_