blob: 32ea341d083472de1272fdc37c3013d78941450d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file build_subgraph.cc
* \brief
*/
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <mxnet/op_attr_types.h>
#include <unordered_set>
#include <stack>
#include <queue>
#include "./subgraph_property.h"
#define DEBUG_SUBGRAPH 0
// Forward declaration only: this file calls nnvm::CreateVariableNode (see
// CutGraphInputs) but does not define it.
namespace nnvm {
NodePtr CreateVariableNode(const std::string& name);
}
namespace mxnet {
namespace op {
namespace sg { // sg stands for subgraph
#if DEBUG_SUBGRAPH
/*!
 * \brief Debug helper: logs the names of all nodes of a subgraph on one line.
 * \param simple_nodes nodes whose names are printed, in the given order
 */
void PrintSubgraph(const std::vector<BiDirectedNode*>& simple_nodes) {
  std::string joined_names;
  for (const BiDirectedNode* sn : simple_nodes) {
    joined_names += sn->node->attrs.name;
    joined_names += ' ';
  }
  LOG(INFO) << "Subgraph node names: " << joined_names;
}
/*!
 * \brief Debug helper: logs the source node name, index and version of an entry.
 * \param entry the node entry to describe
 */
void PrintNodeEntry(const nnvm::NodeEntry& entry) {
  std::string description("NodeEntry: node_name=");
  description += entry.node->attrs.name;
  description += ", index=";
  description += std::to_string(entry.index);
  description += ", version=";
  description += std::to_string(entry.version);
  LOG(INFO) << description;
}
/*!
 * \brief Debug helper: logs every entry of the list via PrintNodeEntry.
 * \param entries pointers to the entries to print; each one is dereferenced
 */
void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
  for (const nnvm::NodeEntry* entry : entries) {
    PrintNodeEntry(*entry);
  }
}
#endif
/*!
 * \brief Builds the bi-directional "simple" graph that mirrors an MXNet computational graph.
 *        One BiDirectedNode is created per visited graph node; besides the raw node pointer
 *        it records, for every consumer node, which input slots of that consumer it feeds.
 * \param g the MXNet computational graph
 * \param simple_nodes output vector of simple nodes, appended in DFSVisit callback order.
 *        The CHECK below requires that a node's inputs were visited before the node
 *        itself and that their indexed-graph ids are valid positions in this vector.
 */
void CreateSimpleGraph(const nnvm::Graph& g,
                       std::vector<BiDirectedNodePtr>* simple_nodes) {
  const auto& idx = g.indexed_graph();
  simple_nodes->reserve(idx.num_nodes());
  DFSVisit(g.outputs, [&](const nnvm::NodePtr& node) {
    BiDirectedNodePtr current = BiDirectedNode::Create();
    current->node = node.get();
    const auto& inputs = current->node->inputs;
    for (size_t slot = 0; slot < inputs.size(); ++slot) {
      const auto producer_id = idx.node_id(inputs[slot].node.get());
      CHECK_LT(producer_id, simple_nodes->size());
      // Register `current` as a consumer of its producer: operator[] creates the
      // (initially empty) slot list on first use, then the slot index is appended.
      (*simple_nodes)[producer_id]->outputs[current->node].push_back(slot);
    }
    simple_nodes->emplace_back(std::move(current));
  });
}
/*!
 * \brief Undoes the labelling of a candidate subgraph: every listed node gets its
 *        label set back to -1 (unlabelled) and the candidate list is emptied.
 * \param g the whole graph, used to map raw nodes to their node ids
 * \param simple_nodes all simple nodes, indexed by node id
 * \param subgraph_nodes nodes whose labels are reset; cleared on return
 */
void ResetNodeLabels(const nnvm::Graph& g,
                     const std::vector<BiDirectedNodePtr>& simple_nodes,
                     std::vector<BiDirectedNode*>* subgraph_nodes) {
  for (const BiDirectedNode* member : *subgraph_nodes) {
    const auto member_id = g.indexed_graph().node_id(member->node);
    simple_nodes[member_id]->label = -1;
  }
  subgraph_nodes->clear();
}
/*!
 * \brief This function traverses the nodes in a computation graph from a starting
 *        node following the input edges and output edges, and marks all nodes that
 *        can be accessed from the starting node. Before the function returns,
 *        it checks whether there is a loop between the potential subgraph
 *        and the outside nodes. If so, it adds the node that should break the loop
 *        to excluded_nodes, resets all labels set by this call, clears
 *        subgraph_nodes and returns false. Otherwise, it returns true.
 * \param g the whole graph
 * \param subgraph_selector determines whether the visited node should be chosen or not
 * \param label the label of the current subgraph
 * \param snid node id of the seed simple node
 * \param simple_nodes all simple nodes in the top sorted order
 * \param subgraph_nodes all the nodes belonging to the same subgraph of seed node;
 *        sorted by node id on success
 * \param excluded_nodes set of nodes that should be excluded from the current subgraph.
 *        May be null while selecting, but must be non-null if a cycle is found.
 * \return true if the labelled subgraph forms no cycle with outside nodes, false otherwise
 */
bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector, const int label,
                   const size_t snid, const std::vector<BiDirectedNodePtr>& simple_nodes,
                   std::vector<BiDirectedNode*>* subgraph_nodes,
                   std::unordered_set<const BiDirectedNode*>* excluded_nodes) {
  const auto& indexed_graph = g.indexed_graph();
  std::queue<BiDirectedNode*> node_queue;
  CHECK_LT(snid, simple_nodes.size());
  CHECK_EQ(simple_nodes[snid]->label, -1);
  simple_nodes[snid]->label = label;
  node_queue.push(simple_nodes[snid].get());
  // key: nodes that serve as input/output nodes to the subgraph
  // value: pair of vectors of nodes in the subgraph. The first vector contains the
  // output nodes of the key in the subgraph, and the second vector contains the
  // input nodes of the key in the subgraph.
  // If a non-subgraph node has inputs from the subgraph and the other non-subgraph node
  // has outputs to the subgraph, and the first non-subgraph node is an ancestor
  // of the second non-subgraph node, there exists a cycle.
  // When breaking the cycle, we want to start from removing the node with the largest node id
  // in the subgraph.
  std::unordered_map<const nnvm::Node*,
                     std::pair<std::vector<const nnvm::Node*>,
                               std::vector<const nnvm::Node*>>> non_subgraph_node_map;
  while (!node_queue.empty()) {
    BiDirectedNode* cur_node = node_queue.front();
    node_queue.pop();
    subgraph_nodes->push_back(cur_node);
    // get qualified adjacent input nodes
    for (auto& e : cur_node->node->inputs) {
      const auto node = e.node.get();
      const auto nid = indexed_graph.node_id(node);
      // validate the id BEFORE it is used to index simple_nodes
      CHECK_LT(nid, simple_nodes.size());
      auto snode = simple_nodes[nid].get();
      const bool select_input =
          (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
          subgraph_selector->SelectInput(*cur_node, *snode);
      if (select_input) {
        // e.node is a subgraph node
        snode->label = label;
        node_queue.push(snode);
      } else if (snode->label == -1) {
        // e.node is an input node of the subgraph
        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
      }
    }
    // get qualified output nodes
    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
      const auto nid = indexed_graph.node_id(it->first);
      // validate the id BEFORE it is used to index simple_nodes
      CHECK_LT(nid, simple_nodes.size());
      auto snode = simple_nodes[nid].get();
      const bool select_output =
          (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
          subgraph_selector->SelectOutput(*cur_node, *snode);
      if (select_output) {
        // it->first is a subgraph node
        snode->label = label;
        node_queue.push(snode);
      } else if (snode->label == -1) {
        // it->first is an output node of the subgraph
        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
      }
    }
  }
  // prepare to check if there is a cycle
  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
  };
  std::vector<const nnvm::Node*> non_subgraph_nodes;
  non_subgraph_nodes.reserve(non_subgraph_node_map.size());
  for (auto& kv : non_subgraph_node_map) {
    auto& output_nodes = kv.second.first;
    std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
    auto& input_nodes = kv.second.second;
    std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
    non_subgraph_nodes.push_back(kv.first);
  }
  // The candidate subgraph does not change during the cycle check, so the set of its
  // raw nodes is built once here instead of once per ancestor query.
  std::unordered_set<const nnvm::Node*> subgraph_node_set;
  subgraph_node_set.reserve(subgraph_nodes->size());
  for (const BiDirectedNode* sn : *subgraph_nodes) {
    subgraph_node_set.insert(sn->node);
  }
  // check whether there is a cycle between the subgraph and its input/output nodes
  auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant) {
    if (ancestor == descendant) return true;
    // Every node is pushed at most once. Without this set, nodes reachable through
    // several paths (diamond patterns) are re-traversed, which is exponential in the
    // worst case and made the former "probably a loop" visit counter fire on valid DAGs.
    // The set also guarantees termination, so the counter is no longer needed.
    std::unordered_set<const nnvm::Node*> visited;
    std::stack<const nnvm::Node*> s;
    s.push(descendant);
    visited.insert(descendant);
    while (!s.empty()) {
      const nnvm::Node* top = s.top();
      s.pop();
      if (top == ancestor) {
        return true;
      }
      for (const auto& entry : top->inputs) {
        const nnvm::Node* input_node = entry.node.get();
        // when searching for the ancestor, the path cannot cross any subgraph node
        if (!subgraph_node_set.count(input_node) && visited.insert(input_node).second) {
          s.push(input_node);
        }
      }
    }
    return false;
  };
  std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
  int excluded_node_id = -1;
  for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
    auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
    CHECK(it1 != non_subgraph_node_map.end());
    auto& output_nodes = it1->second.first;  // has been top sorted
    auto& input_nodes = it1->second.second;  // has been top sorted
    if (!output_nodes.empty() && !input_nodes.empty()) {
      // there is a loop between node i and the subgraph
      const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
                                    indexed_graph.node_id(input_nodes.back()));
      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
    } else if (!input_nodes.empty()) {
      // node i is an input to the subgraph, find out if there is a node j
      // which is an output of the subgraph and also a child of node i.
      for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
        auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
        CHECK(it2 != non_subgraph_node_map.end());
        // i is topologically before j, j might be a direct/indirect output node of i
        CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
        if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first)) {
          // found a loop
          const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
                                        indexed_graph.node_id(it2->second.first.back()));
          excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
        }
      }
    }
  }
  if (excluded_node_id != -1) {
    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
    CHECK_NE(excluded_node_id, static_cast<int>(snid))
        << "A cycle is found in the computational graph between nodes "
        << simple_nodes[excluded_node_id]->node->attrs.name << " and "
        << simple_nodes[snid]->node->attrs.name;
    // The selection code above tolerates a null set, but reporting the node that
    // breaks the cycle requires somewhere to store it; fail loudly instead of crashing.
    CHECK(excluded_nodes != nullptr);
    excluded_nodes->insert(simple_nodes[excluded_node_id].get());
    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
    return false;
  }
  auto sim_node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
    return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
  };
  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), sim_node_cmp);
  return true;
}
/*!
 * \brief Finds all the nodes belonging to the same subgraph given a seed node.
 *        LabelSubgraph is retried (each failed attempt adds a node to an exclusion
 *        set) until it succeeds or the retry budget of simple_nodes.size()^2 attempts
 *        is used up. If no attempt succeeds, the seed node alone is labelled and
 *        returned as a one-node subgraph.
 * \param g the whole graph
 * \param subgraph_selector determines whether the visited node should be chosen or not
 * \param label the label of the current subgraph
 * \param snid node id of the seed simple node
 * \param simple_nodes all simple nodes in the top sorted order
 * \param subgraph_nodes output: the subgraph node candidates. When a LabelSubgraph
 *        attempt succeeded they are sorted by node id.
 */
void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector,
                            const int label, const size_t snid,
                            const std::vector<BiDirectedNodePtr>& simple_nodes,
                            std::vector<BiDirectedNode*>* subgraph_nodes) {
  std::unordered_set<const BiDirectedNode*> excluded_nodes;
  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
  size_t count = 0;
  bool success = false;
  while (!success && count < max_num_retry) {
    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes, subgraph_nodes,
                            &excluded_nodes);
    if (!success) {
      // a failed attempt must have recorded the node that breaks the cycle
      CHECK(!excluded_nodes.empty());
      std::string excluded_node_names;
      for (auto node : excluded_nodes) {
        excluded_node_names += node->node->attrs.name + ", ";
      }
      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
                << ". Excluding nodes " << excluded_node_names << "and retrying";
    }
    ++count;
  }
  if (!success) {
    // Note the leading space in " as a subgraph...": without it the node name and
    // the rest of the sentence were glued together in the log output.
    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
              << simple_nodes[snid]->node->attrs.name
              << " without success because a loop "
                 "is always found between the subgraph and some other nodes. Will treat "
                 "seed node "
              << simple_nodes[snid]->node->attrs.name << " as a subgraph with one node";
    // a failed LabelSubgraph call leaves the candidate list cleared
    CHECK(subgraph_nodes->empty());
    simple_nodes[snid]->label = label;
    subgraph_nodes->push_back(simple_nodes[snid].get());
  }
}
/*!
 * \brief Tries to grow one subgraph from a seed node and records it if it survives filtering.
 *        The seed must be unlabelled and accepted by the selector's Select(). The
 *        pre-selected candidates are passed through the selector's Filter(); candidates
 *        dropped by the filter get their label reset to -1. A non-empty filtered set is
 *        sorted by node id and appended to subgraph_nodes together with its selector,
 *        and *subgraph_id is incremented.
 * \param g the whole graph
 * \param subgraph_selector selector used for this seed
 * \param simple_nodes all simple nodes, indexed by node id
 * \param subgraph_nodes output: list of accepted subgraphs
 * \param subgraph_selectors output: selector of each accepted subgraph (parallel list)
 * \param node the seed simple node, passed to the selector's Select()
 * \param snid node id of the seed simple node
 * \param subgraph_id in/out: id (and label) to use for the next accepted subgraph
 */
void SelectSubgraphNodes(nnvm::Graph* g, SubgraphSelectorV2Ptr subgraph_selector,
                         const std::vector<BiDirectedNodePtr>& simple_nodes,
                         std::vector<std::vector<BiDirectedNode*>>* subgraph_nodes,
                         std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors,
                         const BiDirectedNode* node, const size_t snid, size_t* subgraph_id) {
  const auto& indexed_graph = g->indexed_graph();
  auto node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
    return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
  };
  if (simple_nodes[snid]->label != -1 || !subgraph_selector->Select(*node)) {
    return;
  }
  // pre-select nodes that can be grouped in a subgraph
  std::vector<BiDirectedNode*> preselected_nodes;
  PreSelectSubgraphNodes(*g, subgraph_selector, static_cast<int>(*subgraph_id), snid,
                         simple_nodes, &preselected_nodes);
  // filter out unqualified pre-selected nodes
  std::vector<BiDirectedNode*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
  // Hash sets give O(1) membership tests; the previous std::find-based scans were
  // quadratic in the subgraph size.
  const std::unordered_set<BiDirectedNode*> filtered_set(filtered_nodes.begin(),
                                                         filtered_nodes.end());
  // reset node labels that are not in filtered nodes
  for (BiDirectedNode* n : preselected_nodes) {
    if (filtered_set.count(n) == 0) {
      n->label = -1;
    }
  }
  if (filtered_nodes.empty()) {
    return;
  }
  // make sure filtered_nodes is a subset of preselected_nodes
  const std::unordered_set<BiDirectedNode*> preselected_set(preselected_nodes.begin(),
                                                            preselected_nodes.end());
  for (BiDirectedNode* n : filtered_nodes) {
    CHECK(preselected_set.count(n) != 0)
        << "Node " << n->node->attrs.name
        << " is not found in the pre-selected subgraph nodes."
           " Please make sure that no new nodes were added in your subgraph"
           " selector's Filter function";
  }
  // make sure nodes are sorted
  std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
  subgraph_nodes->push_back(filtered_nodes);
  subgraph_selectors->push_back(subgraph_selector);
  (*subgraph_id)++;
}
/*!
 * \brief Finds subgraphs with all nodes that meet certain criteria.
 *        Every simple node is tried in turn as a seed, each time with a freshly
 *        created selector. All nodes in a subgraph are marked with the same label.
 * \param g the whole graph
 * \param subg_prop property object that creates the subgraph selectors
 * \param simple_nodes all simple nodes; must contain one entry per indexed-graph node
 * \param subgraph_nodes output: node list of each subgraph found
 * \param subgraph_selectors output: selector of each subgraph found (parallel list)
 */
void FindSubgraphs(nnvm::Graph* g,
                   const SubgraphProperty &subg_prop,
                   const std::vector<BiDirectedNodePtr>& simple_nodes,
                   std::vector<std::vector<BiDirectedNode*>>* subgraph_nodes,
                   std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors) {
  CHECK_EQ(g->indexed_graph().num_nodes(), simple_nodes.size());
  size_t next_subgraph_id = 0;
  const size_t num_simple_nodes = simple_nodes.size();
  for (size_t seed_id = 0; seed_id < num_simple_nodes; ++seed_id) {
    SubgraphSelectorV2Ptr seed_selector = subg_prop.CreateSubgraphSelectorV2();
    SelectSubgraphNodes(g, seed_selector, simple_nodes, subgraph_nodes, subgraph_selectors,
                        simple_nodes[seed_id].get(), seed_id, &next_subgraph_id);
  }
}
/*!
 * \brief Sorts node entries in place by their recorded topological position.
 *        Note that entry ids cannot be used to sort entries.
 * \param entry_top_order_map mapping from entry pointer to its topological position in the
 *        graph; every entry being sorted must be present (enforced with CHECK)
 * \param entries node entries to be sorted
 */
void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
                 std::vector<nnvm::NodeEntry*>* entries) {
  // Looks up the topological position of one entry, failing if it is unknown.
  const auto position_of = [&entry_top_order_map](const nnvm::NodeEntry* entry) {
    const auto found = entry_top_order_map.find(entry);
    CHECK(found != entry_top_order_map.end());
    return found->second;
  };
  std::sort(entries->begin(), entries->end(),
            [&position_of](const nnvm::NodeEntry* lhs, const nnvm::NodeEntry* rhs) {
              const size_t lhs_pos = position_of(lhs);
              const size_t rhs_pos = position_of(rhs);
              return lhs_pos < rhs_pos;
            });
}
/*!
 * \brief Given a subgraph, find the input entries of the subgraph, i.e. the input
 *        entries of its nodes whose source node is not part of the same subgraph.
 * \param g the whole graph
 * \param simple_nodes vector of simple nodes in top sorted order
 * \param subgraph_nodes vector of pointers to the simple nodes of one subgraph;
 *        all of them must carry the same label
 * \param entry_top_order_map mapping entry pointer to its top sorted position
 * \param input_entries output: input entries of the subgraph, sorted topologically
 */
void FindInputEntries(const nnvm::Graph& g,
                      const std::vector<BiDirectedNodePtr>& simple_nodes,
                      const std::vector<BiDirectedNode*>& subgraph_nodes,
                      const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
                      std::vector<nnvm::NodeEntry*>* input_entries) {
  const auto& idx = g.indexed_graph();
  int label = -1;
  for (BiDirectedNode* member : subgraph_nodes) {
    if (label != -1) {
      CHECK_EQ(member->label, label);
    } else {
      label = member->label;
    }
    for (auto& in_entry : member->node->inputs) {
      auto* producer = in_entry.node.get();
      // A producer unknown to the indexed graph is a previously created subgraph
      // node (two subgraphs are adjacent); it is always external to this subgraph.
      const bool inside_this_subgraph =
          idx.exist(producer) && simple_nodes[idx.node_id(producer)]->label == label;
      if (!inside_this_subgraph) {
        input_entries->push_back(&in_entry);
      }
    }
  }
  SortEntries(entry_top_order_map, input_entries);
}
/*!
 * \brief Given a subgraph, find the output entries of the subgraph: the input entries
 *        of outside consumer nodes that are fed by subgraph nodes, plus the graph
 *        outputs produced by subgraph nodes.
 * \param g pointer to the whole graph
 * \param simple_nodes vector of simple nodes in top sorted order
 * \param subgraph_nodes vector of pointers to the simple nodes of one subgraph;
 *        all of them must carry the same label
 * \param entry_top_order_map mapping entry pointer to its top sorted position
 * \param output_entries output: output entries of the subgraph, sorted topologically
 */
void FindOutputEntries(nnvm::Graph* g,
                       const std::vector<BiDirectedNodePtr>& simple_nodes,
                       const std::vector<BiDirectedNode*>& subgraph_nodes,
                       const std::unordered_map<const nnvm::NodeEntry*, size_t>&
                         entry_top_order_map,
                       std::vector<nnvm::NodeEntry*>* output_entries) {
  if (subgraph_nodes.empty()) {
    return;
  }
  const auto& idx = g->indexed_graph();
  int label = -1;
  for (BiDirectedNode* member : subgraph_nodes) {
    if (label != -1) {
      CHECK_EQ(member->label, label);
    } else {
      label = member->label;
    }
    for (auto& consumer : member->outputs) {
      nnvm::Node* consumer_node = nullptr;
      if (idx.exist(consumer.first)) {
        // the consumer is a normal graph node (not a subgraph node)
        BiDirectedNode* consumer_sn = simple_nodes[idx.node_id(consumer.first)].get();
        if (consumer_sn->label == label) {
          continue;  // consumer belongs to the current subgraph: internal edge
        }
        consumer_node = consumer_sn->node;
      } else {
        // the consumer is a subgraph node: two subgraphs are adjacent
        consumer_node = consumer.first;
      }
      for (const auto slot : consumer.second) {
        output_entries->push_back(&(consumer_node->inputs[slot]));
      }
    }
  }
  // A graph output still produced by a normal node of the current subgraph is an
  // output entry as well. Graph outputs whose source is unknown to the indexed graph
  // have already been redirected to a subgraph node and are skipped.
  for (auto& graph_output : g->outputs) {
    if (!idx.exist(graph_output.node.get())) {
      continue;
    }
    const auto producer_id = idx.node_id(graph_output.node.get());
    if (simple_nodes[producer_id]->label == label) {
      output_entries->push_back(&graph_output);
    }
  }
  SortEntries(entry_top_order_map, output_entries);
}
/*!
 * \brief Cuts the given input node entries off the graph: each entry is replaced in place
 *        by an entry pointing at a newly created variable node, and a copy of the
 *        original entry is stored at the same position of orig_entries.
 *        The variable is named after the entry's output name followed by a counter
 *        (0 for the first occurrence of a name, 1 for the second, ...), so entries
 *        with the same output name still get distinct variable nodes.
 * \param input_entries entries to cut; the pointed-to entries are overwritten
 * \param orig_entries output: resized to input_entries.size(); slot i holds the original
 *        entry i, or stays default-constructed when entry i was skipped
 * \param skip_var if true, entries whose source node is already a variable are left untouched
 */
void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
                    std::vector<nnvm::NodeEntry> *orig_entries,
                    const bool skip_var = false) {
  orig_entries->resize(input_entries.size());
  // number of times each output name has been seen so far, minus one
  std::unordered_map<std::string, int> name_count_map;
  for (size_t pos = 0; pos < input_entries.size(); ++pos) {
    nnvm::NodeEntry* entry = input_entries[pos];
    // If the node is a variable itself, we may want to skip the node.
    if (entry->node->is_variable() && skip_var) {
      continue;
    }
    (*orig_entries)[pos] = *entry;
    // wrap the entry in a temporary symbol to obtain its output name
    nnvm::Symbol single_output_sym;
    single_output_sym.outputs.push_back(*entry);
    const auto output_names = single_output_sym.ListOutputNames();
    CHECK_EQ(output_names.size(), 1U);
    const std::string& var_name = output_names[0];
    // first occurrence is inserted with counter 0, later ones bump the counter
    const auto inserted = name_count_map.emplace(var_name, 0);
    if (!inserted.second) {
      ++(inserted.first->second);
    }
    const int occurrence = inserted.first->second;
    nnvm::NodePtr var_node = nnvm::CreateVariableNode(var_name + std::to_string(occurrence));
    *entry = nnvm::NodeEntry{var_node, 0, 0};
  }
}
/*!
* \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
* and keep the subgraph in the subgraph node.
* The statement order matters: the subgraph inputs are cut (replaced by variable
* nodes) before the output entries are collected and the subgraph symbol is built.
* \param g the whole graph; must carry the "subgraph_property" attribute
* \param simple_nodes all simple nodes, indexed by node id
* \param subgraph_nodes the simple nodes forming the subgraph to replace
* \param subgraph_selector selector handed to the property's CreateSubgraphNode
* \param subgraph_id id handed to the property's CreateSubgraphNode
* \param entry_top_order_map entry -> topological position; extended with the input
* entries of the newly created subgraph node
*/
void CreateSubgraphNode(nnvm::Graph* g,
const std::vector<BiDirectedNodePtr>& simple_nodes,
const std::vector<BiDirectedNode*>& subgraph_nodes,
const SubgraphSelectorV2Ptr& subgraph_selector,
const size_t subgraph_id,
std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
#if DEBUG_SUBGRAPH
LOG(INFO) << "Searching for input entries...";
#endif
// Collect the entries feeding the subgraph from outside, then cut them: each is
// overwritten with a fresh variable-node entry and the original is saved in
// orig_input_entries at the same index.
std::vector<nnvm::NodeEntry*> input_entries;
FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
std::vector<nnvm::NodeEntry> orig_input_entries;
CutGraphInputs(input_entries, &orig_input_entries, false);
#if DEBUG_SUBGRAPH
PrintNodeEntries(input_entries);
LOG(INFO) << "Searching for output entries...";
#endif
std::vector<nnvm::NodeEntry*> output_entries;
FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
// Create a subgraph for the subgraph node
// (the symbol's outputs are copies of the subgraph's output entries)
nnvm::Symbol sym;
sym.outputs.resize(output_entries.size());
for (size_t i = 0; i < output_entries.size(); ++i) {
sym.outputs[i] = *output_entries[i];
}
const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_selector, subgraph_id);
// Connect the external nodes to the subgraph node.
// NOTE(review): the exact rewiring is defined by the SubgraphProperty implementation,
// which is not visible here; presumably n->inputs is filled from orig_input_entries.
subg_prop->ConnectSubgraphOutputs(n, &output_entries);
subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
const auto& indexed_graph = g->indexed_graph();
// NOTE(review): input_entries[i] is indexed with i < n->inputs.size(); this assumes
// ConnectSubgraphInputs leaves n with at most input_entries.size() inputs - confirm.
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto& e = n->inputs[i];
// update entry_top_order_map with newly created orig_input_entries
// (the new entry inherits the topological position of the cut entry i)
auto it = entry_top_order_map->find(input_entries[i]);
CHECK(it != entry_top_order_map->end());
entry_top_order_map->emplace(&e, it->second);
// update input entries' source simple nodes' outputs map
nnvm::Node* node = e.node.get();
if (indexed_graph.exist(node)) {
const auto nid = indexed_graph.node_id(node);
BiDirectedNode* sn = simple_nodes[nid].get();
// the source no longer feeds the replaced nodes directly ...
for (BiDirectedNode* dest_node : subgraph_nodes) {
sn->outputs.erase(dest_node->node);
}
// ... it now feeds input slot i of the subgraph node instead
sn->outputs[n.get()].push_back(i);
}
}
#if DEBUG_SUBGRAPH
PrintNodeEntries(output_entries);
#endif
}
/*!
 * \brief Adjust a set of nodes belonging to the same subgraph. No new node is created;
 *        the raw nodes of the subgraph are handed to the graph's subgraph property,
 *        which adjusts the selected nodes' attributes.
 *        This can be used to implement peephole optimization. For example, adjust
 *        calibration information of quantized nodes.
 * \param g the whole graph; must carry the "subgraph_property" attribute
 * \param subgraph_nodes the simple nodes forming the subgraph
 * \param subgraph_selector selector that produced the subgraph
 * \param subgraph_id id of the subgraph
 */
void AdjustSubgraphNode(nnvm::Graph* g,
                        const std::vector<BiDirectedNode*>& subgraph_nodes,
                        const SubgraphSelectorV2Ptr& subgraph_selector,
                        const size_t subgraph_id) {
  std::vector<nnvm::Node*> raw_nodes;
  raw_nodes.reserve(subgraph_nodes.size());
  for (const BiDirectedNode* member : subgraph_nodes) {
    raw_nodes.push_back(member->node);
  }
  const SubgraphPropertyPtr& property = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
  property->AdjustSubgraphNode(raw_nodes, subgraph_selector, subgraph_id);
}
} // namespace sg
/*!
* \brief Sort entries of all the nodes' inputs vectors in the topological order.
* This is going to be used to sort input/output entries of subgraphs to keep
* the topological order unchanged.
* The graph is walked depth-first from each graph output with an explicit stack.
* An entry gets the next free position (the current size of the map) either when
* its source node has had all of its inputs processed, or immediately when its
* source node was already visited through another entry.
* \param g the graph whose output and input entries are ordered
* \param entry_top_order_map output: entry pointer -> position; must not be null
*/
void TopSortEntries(const nnvm::Graph& g,
std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
CHECK(entry_top_order_map != nullptr);
std::unordered_set<const nnvm::Node*> visited;
// tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
// The middle element is a cursor: the index of the node's next unprocessed input.
std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
// Number of inputs of a node (0 for a null node). Nodes with control dependencies
// are rejected by the CHECK.
auto in_degree = [] (const nnvm::Node* node)->size_t {
if (!node) {
return 0;
}
CHECK_EQ(node->control_deps.size(), 0U);
return node->inputs.size();
};
for (auto& e : g.outputs) {
nnvm::Node* node = e.node.get();
if (visited.count(node) == 0U) {
s.emplace(node, 0U, &e);
visited.insert(node);
} else {
// The entry's source node has been visited before.
// Marking the order for it.
entry_top_order_map->emplace(&e, entry_top_order_map->size());
}
while (!s.empty()) {
auto& top = s.top();
if (std::get<1>(top) == in_degree(std::get<0>(top))) {
// The node's inputs has been exhausted.
entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
s.pop();
} else {
// The node still has input entries not visited.
CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
// take the next input and advance the cursor stored in the stack frame
auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];
nnvm::Node* input_node = entry.node.get();
if (visited.count(input_node) == 0U) {
// The entry's source node has not been visited.
// Push the entry to the stack for marking order later.
s.emplace(input_node, 0U, &entry);
visited.insert(input_node);
} else {
// The entry's source node has been visited before.
// Marking the order for it.
entry_top_order_map->emplace(&entry, entry_top_order_map->size());
}
}
}
}
}
/*!
 * \brief Entry point of the BuildSubgraph pass.
 *        If the graph carries no "subgraph_property" attribute it is returned unchanged.
 *        Otherwise subgraphs are searched according to the property's selectors and each
 *        one is either replaced by a subgraph node (property type kCreate) or adjusted
 *        in place (property type kAdjust).
 * \param g the graph to partition; consumed by this call
 * \return the processed graph, moved out of g
 */
nnvm::Graph BuildSubgraph(nnvm::Graph&& g) {
  if (!g.HasAttr("subgraph_property")) {  // nothing to partition with: pass the graph through
    LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
                 "The original graph is returned.";
    // `g` is a named rvalue reference, i.e. an lvalue: before C++20 a plain `return g;`
    // copies the graph, so the move has to be spelled out.
    return std::move(g);
  }
  using namespace sg;
  const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
  const std::string& prop_name = subg_prop->HasAttr("property_name")
                                     ? subg_prop->GetAttr<std::string>("property_name")
                                     : "partition graph";
  LOG(INFO) << "start to execute " << prop_name << ".";
  // top sort NodeEntry of all the nodes' inputs
  std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
  TopSortEntries(g, &entry_top_order_map);
  // Create double directional graph for ease of finding subgraphs
  std::vector<BiDirectedNodePtr> simple_nodes;
  CreateSimpleGraph(g, &simple_nodes);
  std::vector<std::vector<BiDirectedNode*>> subgraph_nodes;
  std::vector<SubgraphSelectorV2Ptr> subgraph_selectors;
  FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes, &subgraph_selectors);
  CHECK_EQ(subgraph_nodes.size(), subgraph_selectors.size());
  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
#if DEBUG_SUBGRAPH
    // sanity check: a subgraph must not list the same node twice
    std::set<BiDirectedNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end());
    CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size());
    PrintSubgraph(subgraph_nodes[i]);
#endif
    auto ptype = subg_prop->GetPropertyType();
    if (ptype == SubgraphProperty::SgPropertyType::kCreate) {
      CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], subgraph_selectors[i], i,
                         &entry_top_order_map);
    } else {
      CHECK_EQ(ptype, SubgraphProperty::SgPropertyType::kAdjust);
      AdjustSubgraphNode(&g, subgraph_nodes[i], subgraph_selectors[i], i);
    }
  }
  // subg_prop and prop_name refer into g and must not be used past this point
  return std::move(g);
}
// Registers BuildSubgraph as an nnvm pass under the name "BuildSubgraph" and
// flags it as a pass that changes the graph.
NNVM_REGISTER_PASS(BuildSubgraph)
.describe("Apply a subgraph pass according to the user defined rules "
"in a derived class of SubgraphProperty")
.set_body(BuildSubgraph)
.set_change_graph(true);
} // namespace op
} // namespace mxnet