blob: 4f59635fd3e93bb9fc7ff0a2829c8c06ff83fda6 [file] [log] [blame]
/************************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*************************************************************/
#include "singa/utils/graph.h"
#include <glog/logging.h>
#include <algorithm>
#include <queue>
#include <unordered_set>
#include "singa/utils/common.h"
namespace singa {
using std::map;
using std::string;
using std::vector;
/**************************************************************************
* Implementation for Node class
*************************************************************************/
// Creates a node that carries only a name; every other field keeps whatever
// default the class declaration gives it.
Node::Node(string name) : name(name) {}
// Creates a named node and copies the given attribute map into it.
Node::Node(string name, const std::map<string, string>& attrs) : Node(name) {
  // The parameter shadows the member, hence the explicit this->.
  this->attrs = attrs;
}
// Creates a node with a name, the name of its origin, a partition id and an
// opaque proto pointer. The pointer is stored as-is; nothing is copied.
Node::Node(const string& name, const string& origin, int id, void* proto)
    : Node(name) {
  partition_id = id;
  // These two parameters shadow the members of the same name.
  this->origin = origin;
  this->proto = proto;
}
// Appends dstnode to this node's destination list. Duplicates are not
// filtered out.
void Node::AddDstNode(Node* dstnode) {
  dstnodes.emplace_back(dstnode);
}
// Appends srcnode to this node's source list. Duplicates are not filtered
// out.
void Node::AddSrcNode(Node* srcnode) {
  srcnodes.emplace_back(srcnode);
}
// Removes the first destination node whose name equals dst->name.
//
// Aborts (CHECK) if this node has no such destination. The end-of-range test
// is done before every dereference, so an empty list or a missing node fails
// cleanly instead of reading past the end of dstnodes.
void Node::RemoveDstNode(Node* dst) {
  auto iter = dstnodes.begin();
  while (iter != dstnodes.end() && (*iter)->name != dst->name)
    ++iter;
  CHECK(iter != dstnodes.end())
      << "node " << name << " has no dst node " << dst->name;
  dstnodes.erase(iter);
}
// Removes the first source node whose name equals src->name.
//
// Aborts (CHECK) if this node has no such source. The end-of-range test is
// done before every dereference, so an empty list or a missing node fails
// cleanly instead of reading past the end of srcnodes.
void Node::RemoveSrcNode(Node* src) {
  auto iter = srcnodes.begin();
  while (iter != srcnodes.end() && (*iter)->name != src->name)
    ++iter;
  CHECK(iter != srcnodes.end())
      << "node " << name << " has no src node " << src->name;
  srcnodes.erase(iter);
}
/****************************************************************************
* Implementation for Graph class
****************************************************************************/
// Deletes every node stored in nodes_ (the nodes allocated by AddNode).
Graph::~Graph() {
  for (auto it = nodes_.begin(); it != nodes_.end(); ++it)
    delete *it;
}
// Allocates a node from (name, origin, id, proto), registers it in the graph
// and returns it. Aborts (CHECK) if a node with the same name is already
// registered.
Node* Graph::AddNode(const string& name, const string& origin, int id,
                     void* proto) {
  nodes_.push_back(new Node(name, origin, id, proto));
  Node* node = nodes_.back();
  // Names are the lookup key of name2node_, so they must be unique.
  CHECK(name2node_.find(node->name) == name2node_.end())
      << "node " << node->name << " already exists";
  name2node_[node->name] = node;
  return node;
}
// Allocates a node with the given name and attributes, registers it in the
// graph and returns it. Aborts (CHECK) if a node with the same name is
// already registered.
Node* Graph::AddNode(const string& name,
                     const std::map<string, string>& attrs) {
  nodes_.push_back(new Node(name, attrs));
  Node* node = nodes_.back();
  // Names are the lookup key of name2node_, so they must be unique.
  CHECK(name2node_.find(node->name) == name2node_.end())
      << "node " << node->name << " already exists";
  name2node_[node->name] = node;
  return node;
}
// Links srcnode -> dstnode by recording each node in the other's adjacency
// list. The two updates touch different lists and are independent.
void Graph::AddEdge(Node* srcnode, Node* dstnode) {
  dstnode->AddSrcNode(srcnode);
  srcnode->AddDstNode(dstnode);
}
// Links the nodes named src -> dst. Aborts (CHECK) if either name is not
// registered in the graph; the source name is verified first.
void Graph::AddEdge(const string& src, const string& dst) {
  const auto srcnode = name2node_.find(src);
  const auto dstnode = name2node_.find(dst);
  CHECK(srcnode != name2node_.end()) << "can't find src node " << src;
  CHECK(dstnode != name2node_.end()) << "can't find dst node " << dst;
  AddEdge(srcnode->second, dstnode->second);
}
// Links srcnode -> dstnode and stores attrs for that edge under the key
// returned by GetEdgeName. Attributes already stored under the same key are
// overwritten.
void Graph::AddEdge(Node* srcnode, Node* dstnode,
                    const std::map<string, string>& attrs) {
  AddEdge(srcnode, dstnode);
  const auto edge_key = GetEdgeName(srcnode->name, dstnode->name);
  edge_attrs_[edge_key] = attrs;
}
// Links the nodes named src -> dst (aborting if either is unknown) and stores
// attrs for that edge under the key returned by GetEdgeName. Attributes
// already stored under the same key are overwritten.
void Graph::AddEdge(const string& src, const std::string& dst,
                    const std::map<string, string>& attrs) {
  AddEdge(src, dst);
  const auto edge_key = GetEdgeName(src, dst);
  edge_attrs_[edge_key] = attrs;
}
// Removes one src -> dst edge from both adjacency lists.
//
// Once no src -> dst edge is left, the attributes stored for that edge are
// dropped as well. Otherwise a later attribute-less AddEdge(src, dst) would
// silently pick up the attributes of the removed edge in ToJson()/Reverse().
// Aborts (via the Node removal helpers) if the edge does not exist.
void Graph::RemoveEdge(Node* src, Node* dst) {
  src->RemoveDstNode(dst);
  dst->RemoveSrcNode(src);
  // Parallel edges share one attribute entry, so keep it while any of them
  // remains. Nodes are matched by name, like the removal helpers do.
  bool still_linked = false;
  for (const Node* remaining : src->dstnodes) {
    if (remaining->name == dst->name) {
      still_linked = true;
      break;
    }
  }
  if (!still_linked)
    edge_attrs_.erase(GetEdgeName(src->name, dst->name));
}
// Removes the edge between the nodes named src -> dst. Aborts (CHECK) if
// either name is not registered in the graph; the source name is verified
// first.
void Graph::RemoveEdge(const string &src, const string& dst) {
  const auto srcnode = name2node_.find(src);
  const auto dstnode = name2node_.find(dst);
  CHECK(srcnode != name2node_.end()) << "can't find src node " << src;
  CHECK(dstnode != name2node_.end()) << "can't find dst node " << dst;
  RemoveEdge(srcnode->second, dstnode->second);
}
// sort to make `bottom' nodes be placed in the front positions
void Graph::Sort() {
// nodes to be visited
std::queue<Node*> visiting_nodes;
// visited node set
std::unordered_set<Node*> visited_set;
// visiting_nodes + visted_set
std::unordered_set<Node*> visit_set;;
for (auto node : nodes_) {
// visit nodes without source nodes firstly
if (node->srcnodes.size() == 0) {
visiting_nodes.push(node);
visit_set.insert(node);
}
}
int n = nodes_.size();
nodes_.clear();
while (!visiting_nodes.empty()) {
auto node = visiting_nodes.front();
visiting_nodes.pop();
bool visit = true;
bool bi_direction = false;
// check if a node has a bi-direction edge with its neighbour
for (auto src : node->srcnodes)
for (auto src_of_src : src->srcnodes)
if (strcmp((src_of_src->name).c_str(), (node->name).c_str()) == 0) {
bi_direction = true;
break;
}
// check whether its src nodes number greater than 1
if (bi_direction && (node->srcnodes).size() > 1) {
auto src = node->srcnodes.at(0);
if (visited_set.find(src) == visited_set.end()) {
visit = false;
}
} else {
for (auto src : node->srcnodes)
if (visited_set.find(src) == visited_set.end()) {
visit = false;
break;
}
}
if (visit) {
nodes_.push_back(node);
visited_set.insert(node);
for (auto dst : node->dstnodes) {
// queueing the dst node if it is not queued before
if (visit_set.find(dst) == visit_set.end()) {
visiting_nodes.push(dst);
visit_set.insert(dst);
}
}
} else {
visiting_nodes.push(node);
}
}
CHECK_EQ(nodes_.size(), n);
}
// Builds a new graph with the same node names and node attributes in which
// every edge points the other way. Edge attributes follow their edge; an edge
// without stored attributes gets an empty attribute map.
// NOTE(review): Graph deletes its nodes in the destructor, so returning by
// value is only safe if the copy is elided or Graph's copy semantics account
// for it -- confirm against the class declaration.
const Graph Graph::Reverse() const {
  Graph reversed;
  // First pass: recreate every node (name and attrs only).
  for (auto it = nodes_.begin(); it != nodes_.end(); ++it)
    reversed.AddNode((*it)->name, (*it)->attrs);
  // Second pass: re-add every edge with source and destination swapped.
  for (Node* from : nodes_) {
    for (Node* to : from->dstnodes) {
      map<string, string> attrs;
      const auto stored = edge_attrs_.find(GetEdgeName(from->name, to->name));
      if (stored != edge_attrs_.end())
        attrs = stored->second;
      reversed.AddEdge(to->name, from->name, attrs);
    }
  }
  return reversed;
}
// Serializes the graph to JSON without any extra labels.
string Graph::ToJson() const {
  return ToJson(map<string, string>());
}
// Serializes the graph to a JSON string with a "nodes" array and a "links"
// array.
//
// label maps a node name or an edge name (as returned by GetEdgeName) to
// extra label text. A node's label is "<name> -- " followed by the entry from
// label (if any) and the node's own "label" attribute (if any); every other
// node attribute is written as an additional field. Links refer to nodes by
// their position in nodes_ and carry the edge's label plus its stored
// attributes. Strings are written verbatim, without JSON escaping.
string Graph::ToJson(const map<string, string>& label) const {
  string json = "{\"directed\":1,\n";

  // ---- nodes ----
  json += "\"nodes\":[\n";
  map<string, int> index_of;  // node name -> position in the "nodes" array
  int next_index = 0;
  char sep = ' ';  // a blank before the first element, a comma afterwards
  for (auto node : nodes_) {
    const string& name = node->name;
    string text = name + " -- ";
    const auto extra = label.find(name);
    if (extra != label.end())
      text += extra->second;
    const auto own = node->attrs.find("label");
    if (own != node->attrs.end())
      text += own->second;
    json += StringPrintf("%c{\"id\":\"%s\", \"label\":\"%s\"",
                         sep, name.c_str(), text.c_str());
    for (const auto& kv : node->attrs) {
      // "label" has already been folded into the label text above
      if (kv.first == "label")
        continue;
      json += StringPrintf(", \"%s\":\"%s\"",
                           kv.first.c_str(), kv.second.c_str());
    }
    json += "}\n";
    index_of[name] = next_index++;
    sep = ',';
  }
  json += "]\n,\n";

  // ---- links ----
  json += "\"links\":[\n";
  sep = ' ';
  for (auto src : nodes_) {
    for (auto dst : src->dstnodes) {
      const string edge_name = GetEdgeName(src->name, dst->name);
      string text = "";
      const auto named = label.find(edge_name);
      if (named != label.end())
        text = named->second;
      const int source_index = index_of[src->name];
      const int target_index = index_of[dst->name];
      json += StringPrintf("%c{\"source\":%d, \"target\":%d, \"label\": \"%s\"",
                           sep, source_index, target_index, text.c_str());
      const auto stored = edge_attrs_.find(edge_name);
      if (stored != edge_attrs_.end()) {
        for (const auto& kv : stored->second)
          json += StringPrintf(", \"%s\":\"%s\"",
                               kv.first.c_str(), kv.second.c_str());
      }
      json += "}\n";
      sep = ',';
    }
  }
  json += "]}";
  return json;
}
} // namespace singa