| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file graph_attr_types.cc |
| * \brief Graph node data structure. |
| */ |
| #include <nnvm/graph.h> |
| #include <nnvm/op_attr_types.h> |
| |
| #include <limits> |
| |
| namespace nnvm { |
| |
| const IndexedGraph& Graph::indexed_graph() const { |
| if (indexed_graph_ == nullptr) { |
| indexed_graph_.reset(new IndexedGraph(*this)); |
| } |
| return *indexed_graph_; |
| } |
| |
| // a subgraph should not refer to any nodes with higher level |
| // where "level" refers to the nested depth of the subgraph |
| // e.g. the main graph is level 0 |
| // subgraphs of the main graph is level 1 |
| // subgraphs of the subgraphs of the main graph is level 2 |
| static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>>& subgraphs) { |
| std::vector<const std::vector<nnvm::NodeEntry>*> curr_level; |
| std::vector<const std::vector<nnvm::NodeEntry>*> next_level; |
| std::unordered_map<nnvm::Node*, uint32_t> node2level; |
| for (auto& subgraph : subgraphs) next_level.push_back(&subgraph->outputs); |
| for (uint32_t level = 0; !next_level.empty(); ++level) { |
| curr_level.swap(next_level); |
| next_level.clear(); |
| for (const std::vector<NodeEntry>* graph_ptr : curr_level) { |
| const std::vector<NodeEntry>& graph = *graph_ptr; |
| DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) { |
| nnvm::Node* node = n.get(); |
| // if the node is visited, but on a different level, then check failed |
| // if check failed here or before, we stop doing anything, but raise an error |
| CHECK(!node2level.count(node) || node2level[node] == level) |
| << "A subgraph should not depend on the outputs of nodes on higher levels"; |
| // otherwise, this node belongs to the current level |
| node2level[node] = level; |
| // subgraphs of current node belongs to next level |
| for (const auto& subgraph : n->attrs.subgraphs) { |
| next_level.push_back(&subgraph->outputs); |
| } |
| }); |
| } |
| } |
| } |
| |
| // implement constructor from graph |
| IndexedGraph::IndexedGraph(const Graph& g) { |
| entry_rptr_.push_back(0); |
| std::vector<size_t> inputs_rptr{0}, control_rptr{0}; |
| std::vector<std::shared_ptr<Symbol>> subgraphs; |
| |
| DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs](const ObjectPtr& n) { |
| const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost"); |
| if (!n->is_variable() && is_ghost.get(n->op(), false)) return; |
| CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max()); |
| uint32_t nid = static_cast<uint32_t>(nodes_.size()); |
| CHECK(n); |
| for (const auto& subgraph : n->attrs.subgraphs) subgraphs.push_back(subgraph); |
| // nodes_ |
| IndexedGraph::Node new_node; |
| new_node.source = n.get(); |
| new_node.weak_ref = n; |
| nodes_.emplace_back(std::move(new_node)); |
| // arg_nodes_ |
| if (n->is_variable()) { |
| input_nodes_.push_back(nid); |
| } |
| // node2index_ |
| node2index_[n.get()] = nid; |
| // entry rptr |
| entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); |
| // input entries |
| for (const auto& e : n->inputs) { |
| auto it = node2index_.find(e.node.get()); |
| if (it == node2index_.end() || it->first != e.node.get()) continue; |
| input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); |
| } |
| inputs_rptr.push_back(input_entries_.size()); |
| // control deps |
| for (const auto& nptr : n->control_deps) { |
| if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; |
| auto it = node2index_.find(nptr.get()); |
| CHECK(it != node2index_.end()) << "control dep not found in graph"; |
| control_deps_.push_back(it->second); |
| } |
| control_rptr.push_back(control_deps_.size()); |
| }); |
| if (!subgraphs.empty()) SubgraphSanityCheck(subgraphs); |
| |
| for (const auto& e : g.outputs) { |
| outputs_.emplace_back(NodeEntry{node2index_.at(e.node.get()), e.index, e.version}); |
| } |
| |
| static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs"); |
| // setup array view |
| // input_entries_ and control_rptr must not change after this step. |
| const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); |
| for (size_t nid = 0; nid < nodes_.size(); ++nid) { |
| nodes_[nid].inputs = |
| array_view<NodeEntry>(iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); |
| if (nodes_[nid].source->op() != nullptr && fmutate_inputs.count(nodes_[nid].source->op())) { |
| for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) { |
| mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); |
| } |
| } |
| } |
| const uint32_t* cptr = dmlc::BeginPtr(control_deps_); |
| for (size_t nid = 0; nid < nodes_.size(); ++nid) { |
| nodes_[nid].control_deps = |
| array_view<uint32_t>(cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); |
| } |
| } |
| |
| } // namespace nnvm |