blob: 4f4dfdcc14dd5fcb18f34a038fddbdcc315dd443 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
#define MXNET_IMPERATIVE_CACHED_OP_H_
#include <mxnet/imperative.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace mxnet {
/*!
 * \brief CachedOp Parameters
 *
 * Declared as a dmlc::Parameter: every field below is registered in
 * DMLC_DECLARE_PARAMETER with a default value and a human-readable
 * description.
 */
struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
/*! \brief Maximum number of operators that can be inlined (default 2). */
uint32_t inline_limit;
/*! \brief Segment size of bulk execution during the forward pass. */
uint32_t forward_bulk_size;
/*! \brief Segment size of bulk execution during the backward pass. */
uint32_t backward_bulk_size;
/*! \brief Statically allocate memory to improve speed (default false). */
bool static_alloc;
/*!
 * \brief Optimize for input shapes that do not change between iterations
 *        (default false). Per its description, static_alloc must also be set.
 */
bool static_shape;
/*! \brief Position of argument variables (default: empty tuple). */
nnvm::Tuple<uint32_t> data_indices;
/*! \brief Position of parameters (default: empty tuple). */
nnvm::Tuple<uint32_t> param_indices;
// Field registration: defaults and descriptions for each member above.
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(static_alloc)
.set_default(false)
.describe("Statically allocate memory to improve speed. "
"Memory usage may increase.");
DMLC_DECLARE_FIELD(static_shape)
.set_default(false)
.describe("Optimize for invariant input shapes between iterations. "
"Must also set static_alloc to True. "
"Change of input shapes is still allowed but slower.");
DMLC_DECLARE_FIELD(inline_limit)
.set_default(2)
.describe("Maximum number of operators that can be inlined.");
// Default is dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15).
DMLC_DECLARE_FIELD(forward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during forward pass.");
// NOTE(review): the backward default reads the same environment variable as
// the forward default above - confirm that sharing it is intended.
DMLC_DECLARE_FIELD(backward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during backward pass.");
DMLC_DECLARE_FIELD(data_indices)
.set_default(nnvm::Tuple<uint32_t>())
.describe("Position of argument variables.");
DMLC_DECLARE_FIELD(param_indices)
.set_default(nnvm::Tuple<uint32_t>())
.describe("Position of parameters.");
}
};
/*!
 * \brief Operator object constructed from an nnvm::Symbol and a list of
 *        string key/value flags.
 *
 * Only the small accessors are defined in this header; the constructor,
 * destructor and every forward/backward entry point are defined out of line.
 * The object owns a std::mutex, so it is neither copyable nor movable; the
 * deleted copy operations below state that explicitly.
 */
class CachedOp {
 public:
  /*!
   * \brief Construct from a symbol and a list of (name, value) flags.
   * \param sym symbol this operator is built from.
   * \param flags string key/value pairs; presumably parsed into
   *        CachedOpConfig - confirm against the constructor definition.
   */
  CachedOp(
      const nnvm::Symbol& sym,
      const std::vector<std::pair<std::string, std::string> >& flags);
  ~CachedOp();
  // Non-copyable: already implied by the std::mutex member, spelled out here
  // because the class declares its own destructor.
  CachedOp(const CachedOp&) = delete;
  CachedOp& operator=(const CachedOp&) = delete;

  /*! \brief Number of input nodes in the forward graph's indexed graph. */
  uint32_t num_inputs() const {
    // size() is size_t; the narrowing to uint32_t is intentional.
    return static_cast<uint32_t>(
        fwd_graph_.indexed_graph().input_nodes().size());
  }
  /*! \brief Number of outputs of the forward graph. */
  uint32_t num_outputs() const {
    return static_cast<uint32_t>(fwd_graph_.outputs.size());
  }
  /*!
   * \brief Combined length of the three backward dependency lists
   *        (bwd_ograd_dep_, bwd_in_dep_ and bwd_out_dep_).
   */
  uint32_t num_backward_inputs() const {
    return static_cast<uint32_t>(
        bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size());
  }
  /*! \brief Mutable access to the save_inputs_ flag vector. */
  std::vector<bool>& save_inputs() {
    return save_inputs_;
  }
  /*! \brief Mutable access to the save_outputs_ flag vector. */
  std::vector<bool>& save_outputs() {
    return save_outputs_;
  }
  /*! \brief Mutable input nodes as reported by the forward indexed graph. */
  const std::unordered_set<uint32_t>& mutable_input_nodes() const {
    return fwd_graph_.indexed_graph().mutable_input_nodes();
  }
  /*!
   * \brief Build gradient node entries for a node given its output gradients.
   * \param node node to differentiate.
   * \param ograds gradient entries with respect to the node's outputs.
   * \return node entries; presumably one per input - confirm in the
   *         implementation.
   */
  std::vector<nnvm::NodeEntry> Gradient(
      const nnvm::NodePtr& node,
      const std::vector<nnvm::NodeEntry>& ograds) const;
  /*!
   * \brief Forward entry point.
   * \param op_ptr shared pointer to a CachedOp; presumably the one owning
   *        this object - confirm against callers.
   * \param inputs input arrays.
   * \param outputs output arrays.
   * \return operator state handle, accepted by Backward() as its state
   *         argument.
   */
  OpStatePtr Forward(
      const std::shared_ptr<CachedOp>& op_ptr,
      const std::vector<NDArray*>& inputs,
      const std::vector<NDArray*>& outputs);
  /*!
   * \brief Backward entry point.
   * \param retain_graph flag forwarded by the caller; presumably keeps state
   *        alive for another backward pass - confirm in the implementation.
   * \param state operator state handle.
   * \param inputs input arrays of the backward pass.
   * \param reqs one write request type per output.
   * \param outputs output arrays of the backward pass.
   */
  void Backward(
      const bool retain_graph,
      const OpStatePtr& state,
      const std::vector<NDArray*>& inputs,
      const std::vector<OpReqType>& reqs,
      const std::vector<NDArray*>& outputs);
  // forward storage type inference
  bool ForwardStorageType(
      const nnvm::NodeAttrs& attrs,
      const int dev_mask,
      DispatchMode* dispatch_mode,
      std::vector<int> *in_attrs,
      std::vector<int> *out_attrs);
  // backward storage type inference
  bool BackwardStorageType(
      const nnvm::NodeAttrs& attrs,
      const int dev_mask,
      DispatchMode* dispatch_mode,
      std::vector<int> *in_attrs,
      std::vector<int> *out_attrs);

 private:
  // Implementation-detail types, defined outside this header.
  struct GraphInfo;
  struct DynamicRuntime;
  struct CachedOpState;

  /*! \brief Obtain an operator state handle for the given context. */
  OpStatePtr GetCachedOpState(const Context& ctx);
  bool SetForwardGraph(
      GraphInfo* info,
      const bool recording,
      const std::vector<NDArray*>& inputs);
  bool SetBackwardGraph(
      GraphInfo* info,
      const std::vector<OpReqType>& reqs,
      const std::vector<NDArray*>& inputs,
      bool detect_inplace_addto = false);
  // Dynamic and static variants of the forward/backward passes.
  // NOTE(review): presumably selected by config_.static_alloc - confirm in
  // the implementation.
  OpStatePtr DynamicForward(
      const Context& default_ctx,
      const std::vector<NDArray*>& inputs,
      const std::vector<NDArray*>& outputs);
  void DynamicBackward(
      const bool retain_graph,
      const OpStatePtr& op_state,
      const std::vector<NDArray*>& inputs,
      const std::vector<OpReqType>& reqs,
      const std::vector<NDArray*>& outputs);
  void StaticAllocMemory(
      const OpStatePtr& state_ptr,
      bool recording,
      bool keep_fwd);
  void StaticInitExec(
      const OpStatePtr& state_ptr,
      bool recording,
      bool keep_fwd);
  void StaticRunOps(
      const Context& default_ctx,
      const nnvm::Graph& g,
      const OpStatePtr& state_ptr,
      const std::vector<NDArray *> &state_arrays,
      size_t start_nid,
      size_t end_nid);
  OpStatePtr StaticForward(
      const Context& default_ctx,
      const std::vector<NDArray*>& inputs,
      const std::vector<NDArray*>& outputs);
  void StaticBackward(
      const bool retain_graph,
      const OpStatePtr& state_ptr,
      const std::vector<NDArray*>& inputs,
      const std::vector<OpReqType>& reqs,
      const std::vector<NDArray*>& outputs);

  // Data members. Declaration order is kept as-is because member
  // initialization order follows it.
  CachedOpConfig config_;
  nnvm::Graph fwd_graph_;
  nnvm::Graph grad_graph_;
  nnvm::Graph full_graph_;
  bool inlining_;
  std::vector<nnvm::NodeEntry> ograd_entries_;
  // The sizes of these three lists add up to num_backward_inputs().
  std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
  std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;
  // Exposed by reference through save_inputs() / save_outputs().
  std::vector<bool> save_inputs_, save_outputs_;
  std::vector<OpReqType> bwd_output_reqs_;
  // NOTE(review): presumably guards cached_op_states_ - confirm.
  std::mutex mutex_;
  std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
};
/*! \brief Shared-ownership pointer to a CachedOp. */
using CachedOpPtr = std::shared_ptr<CachedOp>;
} // namespace mxnet
#endif // MXNET_IMPERATIVE_CACHED_OP_H_