blob: 031e62254db7b4e9795eab066b3351efa07a2d63 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file graph_algorithm.h
* \brief This header contains graph algorithms on StaticGraph.
* It is used to compute information such as whether two
* operations can run in parallel, and it helps with allocation.
*/
#ifndef MXNET_NNVM_GRAPH_ALGORITHM_H_
#define MXNET_NNVM_GRAPH_ALGORITHM_H_
#include <nnvm/graph.h>
#include <vector>
namespace nnvm {
namespace pass {
/*!
* \brief Find best path in the DAG, with reward defined
* by sum of reward of each node along the path.
* \param graph the original static graph.
* \param topo_order topo order of the nodes in the graph.
* \param node_reward the reward of each node.
* \param path the output path of nodes.
* \return the total reward of best path.
*/
inline uint32_t MXFindBestPath(
const IndexedGraph& graph,
const std::vector<uint32_t>& node_reward,
std::vector<uint32_t>* path) {
const uint32_t num_nodes = static_cast<uint32_t>(graph.num_nodes());
CHECK_EQ(num_nodes, node_reward.size());
std::vector<uint32_t> best_reward(node_reward.size(), 0);
std::vector<uint32_t> next_node(node_reward.size(), num_nodes);
uint32_t best_solution = 0, best_start_node = 0;
// traverse in reverse topo order
for (uint32_t i = static_cast<uint32_t>(graph.num_nodes()); i != 0; --i) {
const uint32_t nid = i - 1;
best_reward[nid] += node_reward[nid];
if (best_reward[nid] > best_solution) {
best_solution = best_reward[nid];
best_start_node = nid;
}
for (const auto& e : graph[nid].inputs) {
const uint32_t prev = e.node_id;
if (best_reward[nid] > best_reward[prev]) {
best_reward[prev] = best_reward[nid];
next_node[prev] = nid;
}
}
}
path->clear();
uint32_t reward = 0;
for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) {
path->push_back(nid); reward += node_reward[nid];
}
CHECK_EQ(reward, best_solution);
return best_solution;
}
/*!
 * \brief Assign a color index to every node of the graph.
 *
 *  The coloring is greedy: it repeatedly takes the path with the best total
 *  importance (via MXFindBestPath) and gives a fresh color to the nodes on
 *  it that still have non-zero importance. Nodes that share a color lie on
 *  one path, so they cannot run in parallel. Every node that is still
 *  uncolored at the end receives one final, shared color.
 *
 * \param graph the original indexed graph.
 * \param node_importance the importance of each node, indexed by node id;
 *        its size must equal the number of nodes. It is taken by value
 *        because entries are zeroed as their nodes get colored.
 * \param max_ncolor maximum number of colors allowed; must not be 0.
 * \param color output; overwritten with the color index of each node.
 * \return the number of path colors handed out plus one for the final
 *         shared color.
 */
inline uint32_t MXColorNodeGroup(
    const IndexedGraph &graph,
    std::vector<uint32_t> node_importance,
    uint32_t max_ncolor,
    std::vector<uint32_t> *color) {
  CHECK_NE(max_ncolor, 0U);
  CHECK_EQ(graph.num_nodes(), node_importance.size());

  // Valid colors are 0 .. max_ncolor - 1, so `max_ncolor` itself serves as
  // the "not colored yet" marker.
  color->assign(graph.num_nodes(), max_ncolor);

  // At most max_ncolor - 1 colors go to paths; the last one is reserved
  // for whatever remains uncolored afterwards.
  uint32_t next_color = 0;
  while (next_color < max_ncolor - 1) {
    std::vector<uint32_t> chain;
    const uint32_t gain = MXFindBestPath(graph, node_importance, &chain);
    // A zero reward means no node with non-zero importance is left.
    if (gain == 0) break;

    for (const uint32_t nid : chain) {
      // Nodes with zero importance are only passed through, not colored.
      if (node_importance[nid] == 0) continue;
      CHECK_EQ(color->at(nid), max_ncolor);
      color->at(nid) = next_color;
      // Zero the importance so later paths do not claim this node again.
      node_importance[nid] = 0;
    }
    ++next_color;
  }

  // Everything that is still uncolored shares the next color.
  for (uint32_t &assigned : *color) {
    if (assigned == max_ncolor) {
      assigned = next_color;
    }
  }
  return next_color + 1;
}
} // namespace pass
} // namespace nnvm
#endif // MXNET_NNVM_GRAPH_ALGORITHM_H_