blob: c84a3b9be502a9b61af7663020993f29e5093db0 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./imperative_utils.h"
#include "./cached_op.h"
namespace mxnet {
namespace imperative {
void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes,
bool recording) {
using namespace nnvm;
using namespace imperative;
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
const auto imp = Imperative::Get();
std::vector<OpStatePtr>& states = *p_states;
std::vector<NDArray*> ndinputs, ndoutputs;
ShapeVector arg_shapes;
DTypeVector arg_dtypes;
std::vector<OpReqType> req;
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) continue;
auto num_outputs = node.source->num_outputs();
ndinputs.clear();
ndinputs.reserve(node.inputs.size());
for (const auto& j : node.inputs) {
ndinputs.emplace_back(arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
}
ndoutputs.clear();
ndoutputs.reserve(num_outputs);
req.clear();
req.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(i, j);
ndoutputs.emplace_back(arrays[eid]);
req.push_back(array_reqs[eid]);
CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
}
const Context& ctx = ndoutputs[0]->ctx();
const DispatchMode dispatch_mode = dispatch_modes[i];
if (node.source->op() == bwd_cached_op) {
const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
} else if (createop.count(node.source->op())) {
arg_shapes.clear();
arg_dtypes.clear();
arg_shapes.reserve(ndinputs.size());
arg_dtypes.reserve(ndinputs.size());
for (size_t i = 0; i < ndinputs.size(); ++i) {
arg_shapes.emplace_back(ndinputs[i]->shape());
arg_dtypes.emplace_back(ndinputs[i]->dtype());
}
states[i] = createop[node.source->op()](
node.source->attrs, ctx, arg_shapes, arg_dtypes);
imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
if (recording) {
imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
}
} else if (is_layer_backward.get(node.source->op(), false)) {
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
req, dispatch_mode, states[fwd_node_id]);
if (recording) {
imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
}
} else {
imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
if (recording) {
imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
}
}
for (const auto& j : node.inputs) {
size_t eid = idx.entry_id(j);
--ref_count[eid];
if (ref_count[eid] == 0) *arrays[eid] = NDArray();
}
for (size_t j = 0; j < ndoutputs.size(); ++j) {
size_t eid = idx.entry_id(i, j);
if (ref_count[eid] == 0) *arrays[eid] = NDArray();
}
}
}
} // namespace imperative
} // namespace mxnet