[backport 1.5.x] Fix Cached_op with static_shape=true (#15298) (#15380)
* Fix Cached_op with static_shape=true (#15298)
* Fix
* run ci
* trigger
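
The gist of the cached_op.cc change below: with static_shape=true the per-run array list, to which this run's input and output arrays get appended, used to be a plain local copy of state.arrays and disappeared when the call returned. The patch keeps that per-run copy on the op state itself (the new arrays_with_in_out member) and works through a reference, presumably so that anything still pointing into the list stays valid beyond the call, while the state's own arrays list remains untouched from run to run. A minimal sketch of the pattern, using simplified stand-in types rather than MXNet's actual CachedOp/NDArray classes:

// Illustrative sketch only: NDArrayStub and RuntimeState are made-up stand-ins
// showing the before/after pattern of the patch, not MXNet's real classes.
#include <vector>

struct NDArrayStub {};  // stand-in for NDArray

struct RuntimeState {
  std::vector<NDArrayStub*> arrays;             // persistent entries owned by the state
  std::vector<NDArrayStub*> arrays_with_in_out; // per-run list, now also kept on the state
};

// Before the patch (sketch): the per-run list was a local copy, so anything that
// kept referring to its entries after RunBefore() returned was left dangling.
void RunBefore(RuntimeState* state, NDArrayStub* input, NDArrayStub* output) {
  auto arrays = state->arrays;  // local copy; destroyed at end of scope
  arrays.push_back(input);
  arrays.push_back(output);
  // ... execution that may keep referring to `arrays` entries ...
}

// After the patch (sketch): the per-run copy lives on the state, and the local
// name is just a reference to it, so the list outlives this call.
void RunAfter(RuntimeState* state, NDArrayStub* input, NDArrayStub* output) {
  state->arrays_with_in_out = state->arrays;  // copy, but owned by the state
  auto& arrays = state->arrays_with_in_out;   // alias; no second copy
  arrays.push_back(input);
  arrays.push_back(output);
  // ... execution can safely hold on to `arrays` entries across the call ...
}

Copying into a state-owned member rather than mutating state.arrays itself also matches the comment in the hunk: the appended input and output entries are only meant to be valid for this run, so the persistent list is left alone.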
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index d7e1543..efe3801 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -81,6 +81,7 @@
std::vector<NDArray> buff;
std::vector<NDArray*> arrays;
+ std::vector<NDArray*> arrays_with_in_out;
std::vector<OpReqType> array_reqs;
std::vector<OpStatePtr> op_states;
@@ -762,7 +763,8 @@
// We are going to add input and output arrays to the array list.
// The input and output arrays should only be valid for this run,
// so we shouldn't modify the state's array list.
- auto arrays = state.arrays;
+ state.arrays_with_in_out = state.arrays;
+ auto& arrays = state.arrays_with_in_out;
if (config_.static_shape) {
for (auto i : config_.param_indices) {
auto nid = idx.input_nodes()[i];
@@ -1063,7 +1065,8 @@
// We are going to add input and output arrays to the array list.
// The input and output arrays should only be valid for this run,
// so we shouldn't modify the state's array list.
- auto arrays = state.arrays;
+ state.arrays_with_in_out = state.arrays;
+ auto& arrays = state.arrays_with_in_out;
for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
auto eid = state.info.bwd_input_eid[i];
if (eid == kEidNotExist) {
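
The backward hunk above gets the same treatment as the forward one. The one C++ detail the patch leans on is the difference between auto and auto&: "auto arrays = state.arrays" deduces a value type and silently copies the whole pointer vector, while "auto& arrays = state.arrays_with_in_out" only binds a name to the member that was just assigned. A tiny self-contained illustration of that behaviour (not MXNet code):

#include <cassert>
#include <vector>

int main() {
  std::vector<int*> state_arrays(3, nullptr);

  auto copy = state_arrays;    // auto deduces std::vector<int*>: a full copy
  copy.push_back(nullptr);
  assert(copy.size() == 4 && state_arrays.size() == 3);  // original untouched

  auto& alias = state_arrays;  // auto& binds a reference: no copy is made
  alias.push_back(nullptr);
  assert(state_arrays.size() == 4);                       // original grows too
  return 0;
}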
diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc
index 698666f..3e03b6b 100644
--- a/src/nnvm/legacy_op_util.cc
+++ b/src/nnvm/legacy_op_util.cc
@@ -79,7 +79,6 @@
public:
OperatorState(Operator *opr, const OperatorProperty *prop) {
opr_ = opr;
- fwd_init_ = bwd_init_ = false;
in_data_fwd_.resize(prop->ListArguments().size());
in_data_bwd_.resize(prop->ListArguments().size());
@@ -110,19 +109,16 @@
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- if (!fwd_init_) {
- CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
- CHECK_EQ(outputs.size(), out_data_.size());
- // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
- // referred by arg_data_ptr_ will be overriden
- for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
- for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
- for (size_t i = 0; i < aux_data_.size(); ++i) {
- aux_data_[i] = inputs[i + in_data_fwd_.size()];
- }
- for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
- fwd_init_ = true;
+ CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+ CHECK_EQ(outputs.size(), out_data_.size());
+ // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
+ // referred to by arg_data_ptr_ will be overridden
+ for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
+ for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
+ for (size_t i = 0; i < aux_data_.size(); ++i) {
+ aux_data_[i] = inputs[i + in_data_fwd_.size()];
}
+ for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
}
@@ -130,27 +126,22 @@
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
- if (!bwd_init_) {
- CHECK(fwd_init_);
- CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
- // override tblobs pointed by arg_data_ptr_ since they might not contain
- // initialized data during forward pass.
- for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
- *arg_data_ptr_[i] = inputs[i];
- }
- for (size_t i = 0; i < aux_data_.size(); ++i) {
- aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
- }
- CHECK_EQ(outputs.size(), in_grad_.size());
- for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
- bwd_init_ = true;
+ CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
+ // override tblobs pointed to by arg_data_ptr_ since they might not contain
+ // initialized data during forward pass.
+ for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
+ *arg_data_ptr_[i] = inputs[i];
}
+ for (size_t i = 0; i < aux_data_.size(); ++i) {
+ aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
+ }
+ CHECK_EQ(outputs.size(), in_grad_.size());
+ for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
}
private:
Operator *opr_;
- bool fwd_init_, bwd_init_;
// input data blobs for forward and backward
// in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
// performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is