| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Copyright (c) 2019 by Contributors |
| * \file eliminate_common_expr.cc |
| * \brief Eliminate common expressions in the graph |
| * \author Przemyslaw Tredak |
| */ |
| |
#include <mxnet/base.h>
#include <mxnet/op_attr_types.h>

#include <algorithm>
#include <map>
#include <sstream>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
| |
| namespace mxnet { |
| namespace exec { |
| |
| namespace { |
| |
| using nnvm::Node; |
| using nnvm::NodePtr; |
| using nnvm::Graph; |
| using nnvm::IndexedGraph; |
| |
| // NodeInput holds the sufficient subset of NodeEntry fields for Node-input equality tests |
| using NodeInput = std::pair<const Node*, uint32_t>; |
| |
| /*! |
| * \brief Convert a Node's input vector of `NodeEntry` to a vector of the simpler `NodeInput` |
| */ |
| std::vector<NodeInput> ConvertInputs(const std::vector<nnvm::NodeEntry>& inputs) { |
| std::vector<NodeInput> ret; |
| for (const auto& entry : inputs) { |
| ret.emplace_back(entry.node.get(), entry.index); |
| } |
| return ret; |
| } |
| |
| /*! |
| * \brief Determine if two Nodes have equal function such that one Node can be eliminated. |
| */ |
| bool NodeEqual(const Node* n, const Node* m) { |
| if (n->is_variable() || m->is_variable()) return false; |
| if (n->op() != m->op()) return false; |
| // Nodes with different attributes are considered not identical, |
| // though this may reject Node pairs that are in fact functionally the same. |
| if (n->attrs.dict != m->attrs.dict) return false; |
| |
| // Ops that mutate inputs cannot be optimized out |
| static auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs"); |
| if (fmutate_inputs.get(n->op(), nullptr) != nullptr) return false; |
| |
| // Stateful ops cannot be be equal to each other |
| static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState"); |
| if (fstateful.get(n->op(), nullptr) != nullptr) |
| return false; |
| |
| // Check to see if the user has explicitly set THasDeterministicOutput to override the |
| // subsequent determination of Node equality based on resource use. |
| static auto& deterministic_output = |
| Op::GetAttr<THasDeterministicOutput>("THasDeterministicOutput"); |
| if (deterministic_output.contains(n->op())) |
| return deterministic_output[n->op()]; |
| |
| // Ops that require resource could ask for |
| // random resource, so need to be explicitly marked |
| // to be eligible |
| static auto& resource_request = Op::GetAttr<FResourceRequest>("FResourceRequest"); |
| static auto& resource_request_ex = Op::GetAttr<FResourceRequestEx>("FResourceRequestEx"); |
| if (resource_request.get(n->op(), nullptr) != nullptr) return false; |
| if (resource_request_ex.get(n->op(), nullptr) != nullptr) return false; |
| |
| return true; |
| } |
| |
| // Graph traversal to create a list of pairs of identical-function nodes that can be combined. |
| std::vector<std::pair<NodePtr, NodePtr> > GetCommonNodes(const Graph& g) { |
| std::vector<std::pair<NodePtr, NodePtr> > ret; |
| // A map between a vector of inputs and those nodes that have those inputs |
| std::map<std::vector<NodeInput>, std::vector<const NodePtr*> > grouped_nodes; |
| // Traverse the graph and group the nodes by their vector of inputs |
| nnvm::DFSVisit(g.outputs, [&grouped_nodes](const NodePtr& n) { |
| if (n->inputs.size() != 0) { |
| grouped_nodes[ConvertInputs(n->inputs)].push_back(&n); |
| } |
| }); |
| // Now check for identical node ops within the node groups (having identical inputs) |
| for (const auto& pair : grouped_nodes) { |
| auto &node_group = pair.second; // Group of nodes that share the same vector of inputs |
| if (node_group.size() > 1) { |
| std::unordered_set<size_t> visited; |
| for (size_t i = 0; i < node_group.size(); ++i) { |
| if (visited.count(i)) continue; |
| for (size_t j = i + 1; j < node_group.size(); ++j) { |
| // If the two Nodes have equal function, then one Node (called the 'replaced') can |
| // be eliminated in favor of the other Node (the 'src'). |
| if (NodeEqual(node_group[i]->get(), node_group[j]->get())) { |
| visited.insert(j); |
| NodePtr src = *node_group[i]; |
| NodePtr replaced = *node_group[j]; |
| ret.emplace_back(src, replaced); |
| } |
| } |
| } |
| } |
| } |
| return ret; |
| } |
| |
| /*! |
| * \brief Do a single pass of Node elimination given pairs of identical Nodes. |
| */ |
| void EliminateCommonNodes(Graph* g, |
| const std::vector<std::pair<NodePtr, NodePtr> >& common_nodes) { |
| for (const auto &p : common_nodes) { |
| std::vector <NodePtr> nodes_to_change; |
| const NodePtr &src = p.first; |
| const NodePtr &replaced = p.second; |
| // Create a `nodes_to_change` list containing the Nodes that refer to the `replaced` Node |
| // that is targeted for elimination. |
| DFSVisit(g->outputs, [replaced, &nodes_to_change](const NodePtr &n) { |
| for (const auto &dep : n->control_deps) { |
| if (dep == replaced) { |
| nodes_to_change.push_back(n); |
| return; |
| } |
| } |
| for (const auto &inp : n->inputs) { |
| if (inp.node == replaced) { |
| nodes_to_change.push_back(n); |
| return; |
| } |
| } |
| }); |
| |
| // Change references to the `replaced` Node within the `nodes_to_change` list to be |
| // references to the equivalent `src` Node. |
| for (auto &n : nodes_to_change) { |
| for (auto &dep : n->control_deps) { |
| if (dep == replaced) { |
| dep = src; |
| } |
| } |
| for (auto &inp : n->inputs) { |
| if (inp.node == replaced) { |
| inp.node = src; |
| } |
| } |
| } |
| |
| // Add `replaced` Node control dependencies to those of the `src` Node. |
| for (const auto &n : replaced->control_deps) { |
| src->control_deps.push_back(n); |
| } |
| |
| // Change graph outputs driven by the `replaced` Node to now point to the `src` Node. |
| for (auto& out : g->outputs) { |
| if (out.node == replaced) { |
| out.node = src; |
| } |
| } |
| } |
| // Check for duplicates in outputs and |
| // insert Copy nodes as appropriate |
| const Op* copy_op = Op::Get("_copy"); |
| nnvm::NodeEntryMap<size_t> unique_outputs; |
| for (size_t i = 0; i < g->outputs.size(); ++i) { |
| auto kv = unique_outputs.find(g->outputs[i]); |
| if (kv == unique_outputs.end()) { |
| unique_outputs.emplace(g->outputs[i], 0); |
| } else { |
| NodePtr copy_node = Node::Create(); |
| std::ostringstream os; |
| os << kv->first.node->attrs.name << "_" << kv->second << "_copy"; |
| kv->second++; |
| copy_node->attrs.op = copy_op; |
| copy_node->attrs.name = os.str(); |
| copy_node->inputs.emplace_back(kv->first); |
| g->outputs[i] = nnvm::NodeEntry{copy_node, 0, 0}; |
| } |
| } |
| } |
| |
| } // namespace |
| |
| /*! |
| * \brief Simplify a graph by iteratively eliminating Nodes with identical inputs and function. |
| */ |
| nnvm::Graph EliminateCommonExpr(nnvm::Graph&& g) { |
| using nnvm::NodePtr; |
| bool keep_running = true; |
| while (keep_running) { |
| const auto& common_nodes = GetCommonNodes(g); |
| if (common_nodes.empty()) { |
| keep_running = false; |
| } else { |
| EliminateCommonNodes(&g, common_nodes); |
| } |
| } |
| return g; |
| } |
| |
| } // namespace exec |
| } // namespace mxnet |