blob: fd5b64f4777d820167fdea80f063c1d506f86da5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file graph_attr_types.cc
* \brief Implementation of Graph::indexed_graph() and the IndexedGraph constructor.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <limits>
namespace nnvm {
// Return the IndexedGraph view of this graph.
// The index is constructed from *this the first time it is requested and kept
// in indexed_graph_; later calls hand back that cached object. This function
// never rebuilds an index that already exists, and it performs no locking.
const IndexedGraph& Graph::indexed_graph() const {
  const bool needs_build = !indexed_graph_;
  if (needs_build) indexed_graph_.reset(new IndexedGraph(*this));
  return *indexed_graph_;
}
// A subgraph must not refer to nodes that live on a different nesting level.
// Conceptually the main graph is level 0, its subgraphs are level 1, the
// subgraphs of those subgraphs are level 2, and so on. Inside this function the
// counter `level` starts at 0 for the subgraphs passed in (i.e. the subgraphs of
// the main graph) and increases by one per nesting step.
//
// \param subgraphs the subgraphs attached to nodes of the main graph.
// A CHECK fails with "A subgraph should not depend on the outputs of nodes on
// higher levels" when the same node is reached from two different levels.
static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>>& subgraphs) {
  std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
  std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
  std::unordered_map<nnvm::Node*, uint32_t> node2level;
  for (const auto& subgraph : subgraphs) next_level.push_back(&subgraph->outputs);
  for (uint32_t level = 0; !next_level.empty(); ++level) {
    curr_level.swap(next_level);
    next_level.clear();
    for (const std::vector<NodeEntry>* graph_ptr : curr_level) {
      DFSVisit(*graph_ptr, [&next_level, &node2level, level](const ObjectPtr& n) {
        // Single lookup: record the node on this level unless it was seen before.
        auto inserted = node2level.emplace(n.get(), level);
        if (!inserted.second) {
          // A node already recorded on another level is shared across nesting
          // levels, which is not allowed.
          CHECK(inserted.first->second == level)
              << "A subgraph should not depend on the outputs of nodes on higher levels";
          // Seen earlier on this same level: its nested subgraphs were already
          // queued for the next level, so queuing them again is redundant work.
          return;
        }
        // The nested subgraphs of this node belong to the next level.
        for (const auto& subgraph : n->attrs.subgraphs) {
          next_level.push_back(&subgraph->outputs);
        }
      });
    }
  }
}
// Construct an IndexedGraph from graph `g`.
//
// Nodes are numbered in the order DFSVisit visits g.outputs (assumed
// post-order, so a node's inputs and control deps are indexed before the node
// itself; the lookups below rely on that). Non-variable nodes whose op has the
// TIsGhost attribute set are skipped entirely. For each indexed node this
// records:
//   - its source pointer and a weak reference (nodes_),
//   - its id in input_nodes_ if it is a variable,
//   - the prefix offset of its output entries (entry_rptr_),
//   - its input entries (input_entries_, delimited by inputs_rptr); inputs whose
//     producer was not indexed (e.g. ghost nodes) are dropped,
//   - its control dependencies (control_deps_, delimited by control_rptr);
//     ghost control deps are dropped, any other missing one fails a CHECK.
// Afterwards, attached subgraphs are validated with SubgraphSanityCheck, the
// graph outputs are mapped to indexed entries (node2index_.at() fails if an
// output node was not indexed), per-node input/control-dep views are set up,
// and inputs reported by an op's FMutateInputs go into mutable_input_nodes_.
IndexedGraph::IndexedGraph(const Graph& g) {
  entry_rptr_.push_back(0);
  std::vector<size_t> inputs_rptr{0}, control_rptr{0};
  std::vector<std::shared_ptr<Symbol>> subgraphs;
  // The attribute map is loop-invariant; look it up once, not once per node.
  const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
  DFSVisit(g.outputs, [this, &is_ghost, &inputs_rptr, &control_rptr,
                       &subgraphs](const ObjectPtr& n) {
    // Must precede every dereference of n below.
    CHECK(n);
    if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
    CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
    uint32_t nid = static_cast<uint32_t>(nodes_.size());
    for (const auto& subgraph : n->attrs.subgraphs) subgraphs.push_back(subgraph);
    // nodes_
    IndexedGraph::Node new_node;
    new_node.source = n.get();
    new_node.weak_ref = n;
    nodes_.emplace_back(std::move(new_node));
    // input_nodes_: variables are the graph's input nodes
    if (n->is_variable()) {
      input_nodes_.push_back(nid);
    }
    // node2index_
    node2index_[n.get()] = nid;
    // entry rptr
    entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs());
    // input entries; producers that were not indexed (e.g. ghost nodes) are skipped
    for (const auto& e : n->inputs) {
      auto it = node2index_.find(e.node.get());
      if (it == node2index_.end()) continue;
      input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
    }
    inputs_rptr.push_back(input_entries_.size());
    // control deps
    for (const auto& nptr : n->control_deps) {
      if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
      auto it = node2index_.find(nptr.get());
      CHECK(it != node2index_.end()) << "control dep not found in graph";
      control_deps_.push_back(it->second);
    }
    control_rptr.push_back(control_deps_.size());
  });
  if (!subgraphs.empty()) SubgraphSanityCheck(subgraphs);
  for (const auto& e : g.outputs) {
    outputs_.emplace_back(NodeEntry{node2index_.at(e.node.get()), e.index, e.version});
  }
  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
  // Set up the array views.
  // The views point into input_entries_ and control_deps_, so those vectors
  // must not change after this step.
  const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
    nodes_[nid].inputs =
        array_view<NodeEntry>(iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
    const auto* source = nodes_[nid].source;
    if (source->op() != nullptr && fmutate_inputs.count(source->op())) {
      for (uint32_t i : fmutate_inputs[source->op()](source->attrs)) {
        // An unchecked index would silently read another node's input entries.
        CHECK_LT(i, nodes_[nid].inputs.size())
            << "FMutateInputs returned an out-of-range input index";
        mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
      }
    }
  }
  const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
    nodes_[nid].control_deps =
        array_view<uint32_t>(cptr + control_rptr[nid], cptr + control_rptr[nid + 1]);
  }
}
} // namespace nnvm