blob: 6cb4a70324b5b25c2c4f5ebc8dddd5ee5a690d0b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./imperative_utils.h"
#include "./cached_op.h"
namespace mxnet {
namespace imperative {
inline std::vector<NDArray*> NodeInputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*> arrays) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_inputs = node.inputs.size();
std::vector<NDArray*> ndinputs;
ndinputs.reserve(num_inputs);
for (const auto& j : node.inputs) {
size_t eid = idx.entry_id(j);
ndinputs.emplace_back(arrays[eid]);
}
return ndinputs;
}
inline std::vector<NDArray*> NodeOutputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*> arrays) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_outputs = node.source->num_outputs();
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(node_idx, j);
ndoutputs.emplace_back(arrays[eid]);
}
return ndoutputs;
}
inline std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<OpReqType> array_reqs) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_outputs = node.source->num_outputs();
std::vector<OpReqType> req;
req.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(node_idx, j);
req.push_back(array_reqs[eid]);
}
return req;
}
inline void InvokeOperator(const nnvm::IndexedGraph& idx,
const int node_idx,
const bool retain_graph,
const std::vector<NDArray*> arrays,
Context ctx,
std::vector<OpStatePtr>* p_states,
std::vector<NDArray*> ndinputs,
std::vector<NDArray*> ndoutputs,
std::vector<OpReqType> *p_req,
std::vector<uint32_t> *p_ref_count,
std::function<void(const OpStatePtr &state)> invoke) {
static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
std::vector<OpStatePtr>& states = *p_states;
std::vector<OpReqType> &req = *p_req;
std::vector<uint32_t> &ref_count = *p_ref_count;
const nnvm::IndexedGraph::Node& node = idx[node_idx];
if (node.source->op() == bwd_cached_op) {
const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
} else if (createop.count(node.source->op())) {
mxnet::ShapeVector arg_shapes;
nnvm::DTypeVector arg_dtypes;
arg_shapes.reserve(ndinputs.size());
arg_dtypes.reserve(ndinputs.size());
for (auto& ndinput : ndinputs) {
arg_shapes.emplace_back(ndinput->shape());
arg_dtypes.emplace_back(ndinput->dtype());
}
states[node_idx] = createop[node.source->op()](node.source->attrs, ctx, arg_shapes, arg_dtypes);
invoke(states[node_idx]);
} else if (is_layer_backward.get(node.source->op(), false)) {
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
invoke(states[fwd_node_id]);
} else {
invoke(OpStatePtr());
}
for (const auto& j : node.inputs) {
size_t eid = idx.entry_id(j);
--ref_count[eid];
if (ref_count[eid] == 0) {
*arrays[eid] = NDArray();
}
}
for (size_t j = 0; j < ndoutputs.size(); ++j) {
size_t eid = idx.entry_id(node_idx, j);
if (ref_count[eid] == 0) {
*arrays[eid] = NDArray();
}
}
}
void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes,
bool recording,
mxnet::ShapeVector *shapes) {
CHECK(shapes == nullptr);
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) {
continue;
}
std::vector<NDArray*> ndinputs = NodeInputs(idx, i, arrays);
std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
std::vector<OpReqType> req = NodeReq(idx, i, array_reqs);
Context ctx = ndoutputs[0]->ctx();
auto invoke = [&](const OpStatePtr &state) {
const nnvm::IndexedGraph::Node& node = idx[i];
DispatchMode dispatch_mode = dispatch_modes[i];
Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
req, dispatch_mode, state);
if (recording) {
Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
}
};
InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
&req, &ref_count, invoke);
}
}
void NaiveRunGraph(
const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes,
bool recording,
mxnet::ShapeVector *shapes) {
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) {
continue;
}
std::vector<NDArray*> ndinputs = NodeInputs(idx, i, arrays);
std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
std::vector<OpReqType> req;
Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
auto invoke = [&](const OpStatePtr &state) {
const nnvm::IndexedGraph::Node& node = idx[i];
DispatchMode dispatch_mode = DispatchMode::kUndefined;
SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode);
SetWriteInplaceReq(ndinputs, ndoutputs, &req);
Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
req, dispatch_mode, state);
for (size_t j = 0; j < ndoutputs.size(); ++j) {
if (ndoutputs[j]->shape().ndim() == 0) {
ndoutputs[j]->WaitToRead();
ndoutputs[j]->SetShapeFromChunk();
}
size_t eid = idx.entry_id(i, j);
auto shape = ndoutputs[j]->shape();
(*shapes)[eid] = shape;
}
if (recording) {
Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
}
};
InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
&req, &ref_count, invoke);
}
}
} // namespace imperative
} // namespace mxnet