blob: 27680b7134da741f405e6eba5c3ce476562466b9 [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#include "utils/graph.h"

#include <algorithm>
#include <queue>
#include <string>
#include <unordered_set>

#include <glog/logging.h>
namespace singa {
using std::map;
using std::string;
using std::vector;
// Creates a node identified only by `name`; every other member keeps the
// default it is given in the class declaration.
Node::Node(string name) : name(name) {}
// Creates a fully described node.
//   name   - unique node name
//   origin - name of the node this one was derived from (ToJson draws nodes
//            whose origin contains "##" as ellipses)
//   id     - partition id, stored in partition_id
//   proto  - opaque, non-owning pointer stored as-is
Node::Node(const string& name, const string& origin, int id, void* proto)
    : name(name), origin(origin), partition_id(id), proto(proto) {}
// Appends `dstnode` to this node's outgoing (destination) list.
// Duplicates are not filtered.
void Node::AddDstNode(Node* dstnode) { dstnodes.emplace_back(dstnode); }
// Appends `srcnode` to this node's incoming (source) list.
// Duplicates are not filtered.
void Node::AddSrcNode(Node* srcnode) { srcnodes.emplace_back(srcnode); }
// Removes the first entry of dstnodes whose name equals dst->name.
// Matching is by name, not by pointer identity.
// Aborts (glog CHECK) when no such entry exists.
void Node::RemoveDstNode(Node* dst) {
  // find_if compares against end() before any dereference; the previous
  // hand-written loop dereferenced end() when `dst` was missing.
  auto iter = std::find_if(dstnodes.begin(), dstnodes.end(),
                           [dst](const Node* node) {
                             return node->name == dst->name;
                           });
  CHECK(iter != dstnodes.end())
      << "node " << name << " has no dst node " << dst->name;
  dstnodes.erase(iter);
}
// Removes the first entry of srcnodes whose name equals src->name.
// Matching is by name, not by pointer identity.
// Aborts (glog CHECK) when no such entry exists.
void Node::RemoveSrcNode(Node* src) {
  // find_if compares against end() before any dereference; the previous
  // hand-written loop dereferenced end() when `src` was missing.
  auto iter = std::find_if(srcnodes.begin(), srcnodes.end(),
                           [src](const Node* node) {
                             return node->name == src->name;
                           });
  CHECK(iter != srcnodes.end())
      << "node " << name << " has no src node " << src->name;
  srcnodes.erase(iter);
}
// Deletes every node stored in nodes_. Nodes added via either AddNode
// overload end up there, so the graph takes ownership of them.
Graph::~Graph() {
  for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
    delete *it;
  }
}
// Registers `node` (the graph takes ownership) and indexes it by name.
// Node names must be unique; a duplicate name aborts via CHECK.
void Graph::AddNode(Node* node) {
  const string& key = node->name;
  nodes_.push_back(node);
  const bool is_new = name2node_.count(key) == 0;
  CHECK(is_new) << "node " << key << " already exists";
  name2node_[key] = node;
}
// Creates a node called `name`, registers it via AddNode(Node*) and returns
// it. The returned pointer is owned by the graph.
Node* Graph::AddNode(const string& name) {
  auto* created = new Node(name);
  AddNode(created);
  return created;
}
// Adds a directed edge srcnode -> dstnode by updating both endpoints'
// adjacency lists.
void Graph::AddEdge(Node* srcnode, Node* dstnode) {
  dstnode->AddSrcNode(srcnode);
  srcnode->AddDstNode(dstnode);
}
void Graph::AddEdge(const string& src, const string& dst) {
auto srcnode = name2node_.find(src);
CHECK(srcnode != name2node_.end()) << "can't find src node " << src;
auto dstnode = name2node_.find(dst);
CHECK(dstnode != name2node_.end()) << "can't find dst node " << dst;
AddEdge(srcnode->second, dstnode->second);
}
// Removes the directed edge src -> dst from both adjacency lists.
// The source side is updated first; either side missing aborts via CHECK.
void Graph::RemoveEdge(Node* src, Node* dst) {
  Node* const from = src;
  Node* const to = dst;
  from->RemoveDstNode(to);
  to->RemoveSrcNode(from);
}
void Graph::RemoveEdge(const string &src, const string& dst) {
auto srcnode = name2node_.find(src);
CHECK(srcnode != name2node_.end()) << "can't find src node " << src;
auto dstnode = name2node_.find(dst);
CHECK(dstnode != name2node_.end()) << "can't find dst node " << dst;
RemoveEdge(srcnode->second, dstnode->second);
}
// Serializes the graph with no per-node annotations.
string Graph::ToJson() const {
  return ToJson(map<string, string>{});
}
// Serializes the graph as a JSON object of the form
//   {"directed":1, "nodes":[...], "links":[...]}
// Each node entry is
//   {"id":"<name><info[name]>", "color":"<c>","shape":"<s>"}
// where:
//   - info[name] is appended only when present;
//   - the color is chosen by partition_id from a 4-color palette;
//   - the shape is "ellipse" when the node's origin contains "##", else "box".
// Each link is {"source":i, "target":j, "color":"black"}, with i and j being
// positions in nodes_.
// Strings are inserted verbatim (not JSON-escaped).
string Graph::ToJson(const map<string, string>& info) const {
  // The output is built by appending to a std::string. The previous fixed
  // 1024-byte snprintf buffers silently truncated long names/info and then
  // emitted malformed JSON.
  map<string, int> nodeid;
  string disp = "{\"directed\":1,\n";
  // add nodes
  disp += "\"nodes\":[\n";
  const vector<string> colors = {"red", "blue", "black", "green"};
  // see for more shapes at http://www.graphviz.org/doc/info/shapes.html
  const vector<string> shapes = {"box", "ellipse"};
  int id = 0;
  for (const Node* node : nodes_) {
    const string& name = node->name;
    const string& color = colors[(node->partition_id) % colors.size()];
    const string& shape =
        node->origin.find("##") != string::npos ? shapes[1] : shapes[0];
    if (id > 0)
      disp += ",";
    disp += "{\"id\":\"";
    disp += name;
    auto extra = info.find(name);
    if (extra != info.end())
      disp += extra->second;
    disp += "\", \"color\":\"";
    disp += color;
    disp += "\",\"shape\":\"";
    disp += shape;
    disp += "\"}\n";
    nodeid[name] = id++;
  }
  disp += "]\n,";
  // add edges
  disp += "\"links\":[\n";
  bool first = true;
  for (const Node* src : nodes_) {
    for (const Node* dst : src->dstnodes) {
      if (!first)
        disp += ",";
      first = false;
      // NOTE(review): a dst that is not in nodes_ maps to index 0 here via
      // operator[] — kept from the original behavior.
      disp += "{\"source\":" + std::to_string(nodeid[src->name]) +
              ", \"target\":" + std::to_string(nodeid[dst->name]) +
              ", \"color\":\"black\"}\n";
    }
  }
  disp += "]\n";
  return disp + "}";
}
// sort to make `bottom' nodes be placed in the front positions
void Graph::Sort() {
// nodes to be visited
std::queue<Node*> visiting_nodes;
// visited node set
std::unordered_set<Node*> visited_set;
// visiting_nodes + visted_set
std::unordered_set<Node*> visit_set;;
for (auto node : nodes_) {
// visit nodes without source nodes firstly
if (node->srcnodes.size() == 0) {
visiting_nodes.push(node);
visit_set.insert(node);
}
}
int n = nodes_.size();
nodes_.clear();
while (!visiting_nodes.empty()) {
auto node = visiting_nodes.front();
visiting_nodes.pop();
bool visit = true;
bool bi_direction = false;
// check if a node has a bi-direction edge with its neighbour
for (auto src : node->srcnodes)
for (auto src_of_src : src->srcnodes)
if (strcmp((src_of_src->name).c_str(), (node->name).c_str()) == 0) {
bi_direction = true;
break;
}
// check whether its src nodes number greater than 1
if (bi_direction && (node->srcnodes).size() > 1) {
auto src = node->srcnodes.at(0);
if (visited_set.find(src) == visited_set.end()) {
visit = false;
}
} else {
for (auto src : node->srcnodes)
if (visited_set.find(src) == visited_set.end()) {
visit = false;
break;
}
}
if (visit) {
nodes_.push_back(node);
visited_set.insert(node);
for (auto dst : node->dstnodes) {
// queueing the dst node if it is not queued before
if (visit_set.find(dst) == visit_set.end()) {
visiting_nodes.push(dst);
visit_set.insert(dst);
}
}
} else {
visiting_nodes.push(node);
}
}
for (auto node : nodes_) {
LOG(INFO) << "nodes: " << node->name;
}
LOG(INFO) << "finish printing nodes ";
CHECK_EQ(nodes_.size(), n);
}
} // namespace singa