blob: 44a807eda1744eef214b411dc8b64e48d283a483 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file graph_editor.cc
* The functions in this file edit an NNVM graph. Potentially,
* these functions should be moved to NNVM in the future.
*/
#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
// Forward declaration of nnvm::CreateVariableNode, which is not defined in
// this file. CutGraphInputs() below calls it with a name to obtain the
// variable node that replaces a cut entry.
// NOTE(review): presumably declared here because none of the included NNVM
// headers exposes it -- confirm against the NNVM sources.
namespace nnvm {
ObjectPtr CreateVariableNode(const std::string& name);
}
namespace mxnet {
/*!
 * \brief Collect the variable inputs of a computation graph as symbols.
 *
 * Every node reachable from the outputs of `sym` is visited, and for each of
 * its input entries that refers to a variable node a new single-output symbol
 * wrapping that entry is created.
 *
 * No de-duplication is performed: a variable that feeds several input entries
 * yields one symbol per entry.
 *
 * \param sym the symbol whose outputs define the graph to scan.
 * \return the newly created symbols. They are allocated with `new` and are
 *         not released by this function, so the caller has to delete them.
 */
std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym) {
  nnvm::Graph graph;
  graph.outputs = sym.outputs;
  const nnvm::IndexedGraph &indexed = graph.indexed_graph();

  std::vector<nnvm::Symbol *> var_symbols;
  const size_t node_count = indexed.num_nodes();
  for (size_t nid = 0; nid < node_count; ++nid) {
    for (const nnvm::NodeEntry &entry : indexed[nid].source->inputs) {
      // Only entries produced by variable nodes are graph inputs.
      if (!entry.node->is_variable())
        continue;
      nnvm::Symbol *var_sym = new nnvm::Symbol();
      var_sym->outputs.push_back(entry);
      var_symbols.push_back(var_sym);
    }
  }
  return var_symbols;
}
/*!
 * \brief Cut a set of node entries out of a graph by redirecting them to
 *        newly created variable nodes.
 *
 * Each processed entry is overwritten in place so that it refers to output 0
 * of a variable node. Entries that refer to the same (node, index) pair share
 * one variable node; the variable is named after the first output name of a
 * symbol wrapping the original entry.
 *
 * \param input_entries pointers to the node entries to cut. Every processed
 *        entry is modified in place.
 * \param skip_var if true, entries whose node is already a variable are left
 *        untouched and are not recorded in `orig_entries`.
 * \param orig_entries output argument. It is cleared and then receives the
 *        original value of every processed entry, in processing order
 *        (repeated entries are recorded each time they occur).
 * \return always true.
 */
bool CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
                    bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries) {
  // var_nodes[i] is the variable node that replaced (*orig_entries)[i]; the
  // two vectors are kept the same length throughout the loop.
  std::vector<nnvm::ObjectPtr> var_nodes;
  var_nodes.reserve(input_entries.size());
  orig_entries->clear();
  orig_entries->reserve(input_entries.size());

  for (nnvm::NodeEntry *input_entry : input_entries) {
    // If the node is a variable itself, we may want to skip the node.
    if (input_entry->node->is_variable() && skip_var)
      continue;

    const nnvm::NodeEntry &target = *input_entry;
    // Compare by reference: copying a NodeEntry copies its node pointer.
    const auto it = std::find_if(
        orig_entries->begin(), orig_entries->end(),
        [&target](const nnvm::NodeEntry &seen) {
          return seen.node == target.node && seen.index == target.index;
        });

    // Resolve the variable node before appending to orig_entries so that `it`
    // is never used after the vector has been modified.
    nnvm::ObjectPtr n;
    if (it == orig_entries->end()) {
      // First occurrence of this entry: create a new variable node for it.
      nnvm::Symbol sym;
      sym.outputs.push_back(target);
      n = nnvm::CreateVariableNode(sym.ListOutputNames()[0]);
    } else {
      // Otherwise, reuse the variable node created for the earlier occurrence.
      const size_t idx = it - orig_entries->begin();
      CHECK_LT(idx, var_nodes.size());
      n = var_nodes[idx];
    }

    orig_entries->push_back(target);
    var_nodes.push_back(n);
    *input_entry = nnvm::NodeEntry{n, 0, 0};
  }
  return true;
}
} // namespace mxnet