| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file infer_graph_attr_pass.cc |
| * \brief infer graph shape, dtype, and storage type |
| */ |
| |
| #include <mxnet/op_attr_types.h> |
| #include <mxnet/graph_attr_types.h> |
| #include <mxnet/imperative.h> |
| #include "./exec_pass.h" |
| #include "../operator/operator_common.h" |
| #include "../common/exec_utils.h" |
| |
| namespace mxnet { |
| namespace exec { |
| |
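/*!\brief
 * Apply an operator's attribute inference function to its input and output
 * attribute vectors. This generic version simply forwards to the registered
 * inference function; the graph, node id, and dispatch mode are unused.
 */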
| template<typename AttrType, typename FInfer> |
| bool ApplyOpInferAttr(const nnvm::Graph& g, |
| const FInfer& finfer, |
| const NodeAttrs& attrs, |
| const uint32_t nid, |
| std::vector<AttrType>* in_attrs, |
| std::vector<AttrType>* out_attrs, |
| DispatchMode* dispatch_mode) { |
| return finfer(attrs, in_attrs, out_attrs); |
| } |
| |
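/*!\brief
 * Specialization for storage type inference, whose inference function also
 * takes the node's device mask and produces a dispatch mode. Inference
 * failure is fatal here, and a fallback to the default FCompute path is
 * logged for diagnostics.
 */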
| template<> |
| bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g, |
| const FInferStorageType& finfer, |
| const NodeAttrs& attrs, |
| const uint32_t nid, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs, |
| DispatchMode* dispatch_mode) { |
| const DevMaskVector& dev_masks = g.GetAttr<DevMaskVector>("dev_mask"); |
| const bool success = finfer(attrs, dev_masks[nid], dispatch_mode, in_attrs, out_attrs); |
| if (!success) { |
| LOG(FATAL) << "Operator not implemented: " |
| << common::operator_stype_string(attrs, dev_masks[nid], *in_attrs, *out_attrs); |
| } |
| if (*dispatch_mode == DispatchMode::kFComputeFallback) { |
| common::LogStorageFallback(attrs, dev_masks[nid], in_attrs, out_attrs); |
| } |
| return true; |
| } |
| |
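/*!\brief
 * Infer the attributes of a backward node from its forward node. The
 * correspondence between backward outputs and forward inputs is recovered by
 * replaying the forward node's FGradient function on placeholder output
 * gradient entries.
 */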
| template<typename AttrType, typename IsNone> |
| inline void GetAttrFromForwardNode(const uint32_t nid, |
| const nnvm::IndexedGraph &idx, |
| std::vector<AttrType>* rshape_ptr, |
| std::vector<bool>* inference_finished, |
| IsNone fis_none) { |
| std::vector<AttrType>& rshape = *rshape_ptr; |
| const nnvm::IndexedGraph::Node& inode = idx[nid]; |
| // gradient function, used to get node correspondence. |
| static auto& fgrad = |
| Op::GetAttr<nnvm::FGradient>("FGradient"); |
| nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0]; |
| const nnvm::IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; |
| // use gradient function to find out the correspondence. |
| std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs()); |
| for (size_t i = 0; i < ograd.size(); ++i) { |
| ograd[i].index = static_cast<uint32_t>(i); |
| } |
| // input gradient list |
| const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); |
| const nnvm::Node* igrad_node = nullptr; |
| bool all_attrs_known = true; |
  // Input gradient assignment
| for (size_t i = 0; i < igrad.size(); ++i) { |
| if (igrad[i].node->op() == inode.source->op()) { |
| uint32_t eid = idx.entry_id(nid, igrad[i].index); |
| if (fis_none(rshape[idx.entry_id(fnode.inputs[i])])) { |
        // Skip an empty forward shape: it may not be available yet, and it
        // may still be inferred in one of the next few passes.
| all_attrs_known = false; |
| } else { |
| if (fis_none(rshape[eid])) { |
| rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; |
| } else { |
| CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) |
| << "Backward shape inconsistent with the forward shape"; |
| } |
| } |
| if (igrad_node == nullptr) { |
| igrad_node = igrad[i].node.get(); |
| } else { |
| CHECK(igrad_node == igrad[i].node.get()); |
| } |
| } |
| } |
| // out grad entries |
| CHECK(igrad_node != nullptr) |
| << "Cannot find matching backward op for " << inode.source->attrs.name; |
| for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { |
| const nnvm::NodeEntry& e = igrad_node->inputs[i]; |
| if (e.node == nullptr) { |
| uint32_t eid = idx.entry_id(inode.inputs[i]); |
| if (fis_none(rshape[eid])) { |
| rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; |
| } |
| if (fis_none(rshape[eid])) { |
| // If the attr is still unknown |
| all_attrs_known = false; |
| } |
| } |
| } |
| (*inference_finished)[nid] = all_attrs_known; |
| } |
| |
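/*!\brief
 * Counterpart of GetAttrFromForwardNode for backward nodes whose forward node
 * is a fused operator. The attributes of the original (pre-fusion) forward
 * node are first retrieved through the fusion operator's FAccessSubgraph*
 * function, then matched up via the usual FGradient-based correspondence.
 */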
| template<typename FAccessSubgraphType, typename AttrType, typename IsNone> |
| void GetAttrFromFusedNode(uint32_t nid, |
| const nnvm::IndexedGraph& idx, |
| std::vector<AttrType>* rshape_ptr, |
| std::vector<bool>* inference_finished, |
| IsNone fis_none, |
| const std::string& infer_fusion_name) { |
| std::vector<AttrType>& rshape = *rshape_ptr; |
| const auto& inode = idx[nid]; |
| // gradient function, used to get node correspondence. |
| static auto& fgrad = |
| Op::GetAttr<nnvm::FGradient>("FGradient"); |
| nnvm::ObjectPtr fused_fwd_ptr = inode.source->control_deps[0]; |
| static auto& finfer_fused_shape = |
| Op::GetAttr<FAccessSubgraphType>(infer_fusion_name); |
| auto finfer = finfer_fused_shape.get(fused_fwd_ptr->op(), nullptr); |
| CHECK(finfer != nullptr) << "Operator " << fused_fwd_ptr->attrs.name << |
| " is marked as Fusion but does not allow accessing attributes"; |
| const auto& inferred_attrs = finfer(fused_fwd_ptr->attrs); |
| const auto& fwd_ptr = std::get<0>(inferred_attrs); |
| const auto& input_attrs = std::get<1>(inferred_attrs); |
| const auto& output_attrs = std::get<2>(inferred_attrs); |
| |
| // use gradient function to find out the correspondence. |
| std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs()); |
| for (size_t i = 0; i < ograd.size(); ++i) { |
| ograd[i].index = static_cast<uint32_t>(i); |
| } |
| // input gradient list |
| const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); |
| const nnvm::Node* igrad_node = nullptr; |
| bool all_attrs_known = true; |
| // Set the attributes of output gradients |
| // using attributes of forward node inputs |
| for (size_t i = 0; i < igrad.size(); ++i) { |
| if (igrad[i].node->op() == inode.source->op()) { |
| uint32_t eid = idx.entry_id(nid, igrad[i].index); |
| if (fis_none(input_attrs[i])) { |
        // Skip an empty forward shape: it may not be available yet, and it
        // may still be inferred in one of the next few passes.
| all_attrs_known = false; |
| } else { |
| if (fis_none(rshape[eid])) { |
| rshape[eid] = input_attrs[i]; |
| } else { |
| CHECK_EQ(rshape[eid], input_attrs[i]) |
| << "Backward shape inconsistent with the forward shape"; |
| } |
| } |
| if (igrad_node == nullptr) { |
| igrad_node = igrad[i].node.get(); |
| } else { |
| CHECK(igrad_node == igrad[i].node.get()); |
| } |
| } |
| } |
| |
| // Set the attributes of input gradients |
| // using attributes of forward node outputs |
| CHECK(igrad_node != nullptr) |
| << "Cannot find matching backward op for " << inode.source->attrs.name; |
| for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { |
| const nnvm::NodeEntry& e = igrad_node->inputs[i]; |
| if (e.node == nullptr) { |
| uint32_t eid = idx.entry_id(inode.inputs[i]); |
| if (fis_none(rshape[eid])) { |
| rshape[eid] = output_attrs[e.index]; |
| } |
| if (fis_none(rshape[eid])) { |
| // If the attr is still unknown |
| all_attrs_known = false; |
| } |
| } |
| } |
| (*inference_finished)[nid] = all_attrs_known; |
| } |
| |
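/*!\brief
 * Provide the currently inferred attributes of a fusion node's control
 * dependencies (the nodes forming the fused subgraph) to the fusion operator
 * through its FProvideSubgraph* function.
 */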
| template <typename FProvideSubgraphType, typename AttrType> |
| void ProvideAttrToFusion(const uint32_t nid, |
| const nnvm::IndexedGraph& idx, |
| const std::vector<AttrType>& rshape, |
| const std::string& provide_fusion_name) { |
| const auto& inode = idx[nid]; |
| std::vector<std::vector<AttrType>> in_attrs; |
| std::vector<std::vector<AttrType>> out_attrs; |
| for (const auto& dep_node : inode.source->control_deps) { |
| in_attrs.push_back({}); |
| out_attrs.push_back({}); |
    auto &current_in_attrs = in_attrs.back();
    auto &current_out_attrs = out_attrs.back();
| uint32_t dep_node_id = idx.node_id(dep_node.get()); |
| for (const auto& e : idx[dep_node_id].inputs) { |
| current_in_attrs.push_back(rshape[idx.entry_id(e)]); |
| } |
| for (size_t i = 0; i < dep_node->num_outputs(); ++i) { |
| current_out_attrs.push_back(rshape[idx.entry_id(dep_node_id, i)]); |
| } |
| } |
| auto provide = |
| Op::GetAttr<FProvideSubgraphType>(provide_fusion_name).get(inode.source->op(), nullptr); |
  CHECK(provide != nullptr) <<
    "Encountered a fusion operator that does not implement the subgraph attribute provider " <<
    provide_fusion_name << ".";
| provide(inode.source->attrs, inode.source->control_deps, in_attrs, out_attrs); |
| } |
| |
| /*!\brief |
 * This is a duplicate of the InferAttr function in nnvm, with minor modifications
 * to support inferring storage type, whose inference function signature differs
 * from those of the shape/type inference functions. The nnvm InferAttr will be
 * deprecated in the future. Use the interfaces InferShape, InferType, and
 * InferStorageType to call this function.
| * |
| * \param ret graph used for attribute inference |
 * \param empty_val empty value of the attribute
| * \param infer_name name of the function used for attribute inference |
| * \param infer_fusion_name name of the function used for accessing attributes in fused nodes |
| * \param input_name name of the attribute in the graph used to store the |
| * input data for attribute inference |
| * \param attr_key_name name of the attribute used for inference for variable nodes |
| * \param attr_name name of the inferred attribute |
| * \param unknown_name name of the attribute storing number of entries |
| * impossible to infer |
| * \param fis_none function returning true for not fully inferred values |
| * \param fdefault default function used for inference if the node does not |
| * provide its own implementation. |
| * \param bwd_identity_assign whether the attributes of forward NDArray and backward |
| * NDArray have to be the same. False only for storage |
| * type inference |
| * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for |
| * storage type inference |
| * \param default_mode_val default value of the dispatch mode attribute on the node. Used |
| * for storage type inference |
| */ |
| template<typename AttrType, typename FInferType, typename FAccessSubgraphType, |
| typename FProvideSubgraphType, typename IsNone, typename FDefault> |
| nnvm::Graph InferAttr(nnvm::Graph &&ret, |
| const AttrType empty_val, |
| const char* infer_name, |
| const char* infer_fusion_name, |
| const char* provide_fusion_name, |
| const char* input_name, |
| const char* attr_key_name, |
| const char* attr_name, |
| const char* unknown_name, |
| IsNone fis_none, |
| FDefault fdefault, |
| bool bwd_identity_assign, |
| const char* dispatch_mode_name, |
| const DispatchMode default_mode_val = DispatchMode::kUndefined) { |
| using nnvm::IndexedGraph; |
| using nnvm::Op; |
| using AttrVector = std::vector<AttrType>; |
| using NodeAttrVector = std::vector<DispatchMode>; |
| using dmlc::any; |
| |
| const IndexedGraph& idx = ret.indexed_graph(); |
| static auto& finfer_shape = |
| Op::GetAttr<FInferType>(infer_name); |
| static auto& is_backward = |
| Op::GetAttr<nnvm::TIsBackward>("TIsBackward"); |
  // result attribute vector, indexed by entry id
| AttrVector rshape; |
| // vector holding information which operators |
| // finished attribute inference |
| std::vector<bool> inference_finished(idx.num_nodes(), false); |
| // dispatch mode vector |
| DispatchModeVector dispatch_modes; |
| if (ret.attrs.count(attr_name) != 0) { |
| rshape = ret.MoveCopyAttr<AttrVector>(attr_name); |
| } else { |
| rshape.resize(idx.num_node_entries(), empty_val); |
| } |
| |
| if (ret.attrs.count(input_name) != 0) { |
| const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name); |
| CHECK_LE(shape_args.size(), idx.input_nodes().size()) |
| << "More provided " << attr_name << "s than number of arguments."; |
| for (size_t i = 0; i < shape_args.size(); ++i) { |
| rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; |
| } |
| } |
| |
| // get the shape hints |
| std::string shape_hints_key = std::string(attr_name) + "_hints"; |
| if (ret.attrs.count(shape_hints_key)) { |
| nnvm::NodeEntryMap<AttrType> shape_hints = |
| ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key); |
| for (const auto& kv : shape_hints) { |
| nnvm::NodeEntry e = kv.first; |
| if (idx.exist(e.node.get())) { |
| rshape[idx.entry_id(kv.first)] = kv.second; |
| } |
| } |
| } |
| |
| std::string shape_attr_key; |
| if (ret.attrs.count(attr_key_name) != 0) { |
| shape_attr_key = ret.GetAttr<std::string>(attr_key_name); |
| // erase the provided arguments |
| ret.attrs.erase(attr_key_name); |
| } |
| |
| // limit inference to part of the graph |
| uint32_t node_start = 0, node_end = idx.num_nodes(); |
| if (ret.attrs.count("node_range")) { |
| const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range"); |
| node_start = range.first; |
| node_end = range.second; |
| CHECK_GE(node_start, 0); |
| CHECK_LE(node_end, idx.num_nodes()); |
| ret.attrs.erase("node_range"); |
| } |
| uint32_t entry_start = 0, entry_end = idx.num_node_entries(); |
| if (ret.attrs.count("entry_range")) { |
| const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range"); |
| entry_start = range.first; |
| entry_end = range.second; |
| CHECK_GE(entry_start, 0); |
| CHECK_LE(entry_end, idx.num_node_entries()); |
| ret.attrs.erase("entry_range"); |
| } |
| // populate the node attribute vector |
| if (dispatch_mode_name != nullptr) { |
| if (ret.attrs.count(dispatch_mode_name) != 0) { |
| dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name); |
| } else { |
| LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph"; |
| } |
| } |
| |
  // Temp space for attribute inference.
| std::vector<AttrType> ishape, oshape; |
| |
| // inference step function for nid |
| auto infer_step = [&](uint32_t nid, bool last_iter) { |
| if (inference_finished[nid]) return; |
| const auto& inode = idx[nid]; |
| const uint32_t num_inputs = inode.inputs.size(); |
| const uint32_t num_outputs = inode.source->num_outputs(); |
| if (inode.source->is_variable()) { |
| // Variable node. No operator. Only one output entry. |
| CHECK(inode.source->op() == nullptr); |
| CHECK_EQ(num_outputs, 1U); |
| const uint32_t out_ent_id = idx.entry_id(nid, 0); |
| if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { |
| auto it = inode.source->attrs.dict.find(shape_attr_key); |
| if (it != inode.source->attrs.dict.end()) { |
| std::istringstream is(it->second); |
| CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; |
| } |
| } |
| if (!fis_none(rshape[out_ent_id])) { |
| inference_finished[nid] = true; |
| } |
| // assign a default value to node attribute |
| if (dispatch_mode_name != nullptr) { |
| op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val); |
| } |
| } else if (is_backward.get(inode.source->op(), false) && |
| inode.source->control_deps.size() && bwd_identity_assign) { |
| CHECK(dispatch_mode_name == nullptr) |
| << "Backward inference for node attributes is not available"; |
| CHECK_GE(inode.source->control_deps.size(), 1U) |
| << "BackwardOp need to have control_deps to its forward op"; |
| nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0]; |
| CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; |
| |
| static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper"); |
| if (!is_fusion_helper.get(fwd_ptr->op(), false)) { |
| GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none); |
| } else { |
| GetAttrFromFusedNode<FAccessSubgraphType>(nid, idx, &rshape, &inference_finished, |
| fis_none, infer_fusion_name); |
| } |
| } else { |
| DispatchMode* dispatch_mode = nullptr; |
| // Forward operator inference. |
| ishape.resize(num_inputs, empty_val); |
| for (uint32_t i = 0; i < ishape.size(); ++i) { |
| ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; |
| } |
| oshape.resize(num_outputs, empty_val); |
| for (uint32_t i = 0; i < oshape.size(); ++i) { |
| oshape[i] = rshape[idx.entry_id(nid, i)]; |
| } |
| if (dispatch_mode_name != nullptr) { |
| dispatch_mode = &dispatch_modes[nid]; |
| } |
| auto finfer = finfer_shape.get(inode.source->op(), fdefault); |
| if (finfer != nullptr) { |
| // Call inference function of the operator. |
| try { |
| static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion"); |
| if (is_fusion.get(inode.source->op(), false)) { |
| ProvideAttrToFusion<FProvideSubgraphType>(nid, idx, rshape, provide_fusion_name); |
| } |
| ApplyOpInferAttr(ret, finfer, inode.source->attrs, |
| nid, &ishape, &oshape, dispatch_mode); |
| bool finished = true; |
| for (const auto& attr : ishape) { |
| if (fis_none(attr)) finished = false; |
| } |
| for (const auto& attr : oshape) { |
| if (fis_none(attr)) finished = false; |
| } |
| inference_finished[nid] = finished; |
| } catch (const std::exception& e) { |
| throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); |
| } |
| } else { |
        // Operator does not provide an attribute inference function,
        // so check whether everything was already inferred by other operators.
| bool all_attrs_known = true; |
| for (const auto& attr : ishape) { |
| if (fis_none(attr)) { |
| all_attrs_known = false; |
| } |
| } |
| for (const auto& attr : oshape) { |
| if (fis_none(attr)) { |
| all_attrs_known = false; |
| } |
| } |
| inference_finished[nid] = all_attrs_known; |
| if (!all_attrs_known) { |
| CHECK(!last_iter) |
| << "Attribute " << infer_name |
| << " is not registered by op " << inode.source->op()->name |
| << ". We are not able to complete the inference because of this"; |
| } |
| } |
| // Save to the result map. |
| for (uint32_t i = 0; i < num_inputs; ++i) { |
| rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; |
| } |
| for (uint32_t i = 0; i < num_outputs; ++i) { |
| rshape[idx.entry_id(nid, i)] = oshape[i]; |
| } |
| } |
| }; |
| |
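  // Alternate forward and backward sweeps over the node range until the
  // number of unknown attributes stops decreasing; if some operator still
  // reports unfinished inference at that point, run one last sweep.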
| size_t last_num_unknown; |
| size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0; |
| size_t num_unknown_entry_attr = entry_end - entry_start; |
| size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode; |
| bool last_iter = false; |
| bool do_next_iteration = true; |
| int i = 0; |
| do { |
| if (i % 2 == 0) { |
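      // forward inference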
| for (uint32_t nid = node_start; nid < node_end; ++nid) { |
| infer_step(nid, last_iter); |
| } |
| } else { |
| // backward inference |
| for (uint32_t i = node_end; i != node_start; --i) { |
| infer_step(i - 1, last_iter); |
| } |
| } |
| last_num_unknown = num_unknown; |
| num_unknown = 0; |
| for (size_t j = entry_start; j < entry_end; ++j) { |
| if (fis_none(rshape[j])) { |
| ++num_unknown; |
| } |
| } |
| if (dispatch_mode_name) { |
| for (size_t i = node_start; i < node_end; i++) { |
| if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown; |
| } |
| } |
| do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown; |
| if (!do_next_iteration && !last_iter) { |
| // Check if every op agrees that it should be |
| // the end of attribute inference. If not, |
| // perform one final step |
| for (const bool done : inference_finished) { |
| do_next_iteration = do_next_iteration || !done; |
| } |
| last_iter = true; |
| } |
| ++i; |
| } while (do_next_iteration); |
  // set the inferred attributes
  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
  // set the dispatch modes
| if (dispatch_mode_name) { |
| ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes)); |
| } |
  // number of attribute entries (and dispatch modes) that remain unknown
| ret.attrs[unknown_name] = std::make_shared<any>(num_unknown); |
| return ret; |
| } |
| |
| /*!\brief |
| * This is a version of the InferAttr function specifically for shape inference. |
| * |
| * \param ret graph used for attribute inference |
 * \param empty_val empty value of the attribute
| * \param infer_name name of the function used for attribute inference |
| * \param input_name name of the attribute in the graph used to store the |
| * input data for attribute inference |
| * \param attr_key_name name of the attribute used for inference for variable nodes |
| * \param attr_name name of the inferred attribute |
| * \param unknown_name name of the attribute storing number of entries |
| * impossible to infer |
| * \param fis_none function returning true for not fully inferred values |
| * \param fnum_unknown function returning how many elements are unknown in |
| * partially inferred value of the attribute |
| * \param fdefault default function used for inference if the node does not |
| * provide its own implementation. |
| * \param bwd_identity_assign whether the attributes of forward NDArray and backward |
| * NDArray have to be the same. False only for storage |
| * type inference |
| * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for |
| * storage type inference |
| * \param default_mode_val default value of the dispatch mode attribute on the node. Used |
| * for storage type inference |
| */ |
| template<typename IsNone, typename FDefault, typename FNumUnknown> |
| nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, |
| const mxnet::TShape empty_val, |
| const char* infer_name, |
| const char* input_name, |
| const char* attr_key_name, |
| const char* attr_name, |
| const char* unknown_name, |
| IsNone fis_none, |
| FNumUnknown fnum_unknown, |
| FDefault fdefault, |
| bool bwd_identity_assign, |
| const char* dispatch_mode_name, |
| const DispatchMode default_mode_val = DispatchMode::kUndefined) { |
| using nnvm::IndexedGraph; |
| using nnvm::Op; |
| using AttrType = mxnet::TShape; |
| using FInferType = mxnet::FInferShape; |
| using AttrVector = std::vector<AttrType>; |
| using NodeAttrVector = std::vector<DispatchMode>; |
| using dmlc::any; |
| const IndexedGraph& idx = ret.indexed_graph(); |
| static auto& finfer_shape = |
| Op::GetAttr<FInferType>(infer_name); |
| static auto& is_backward = |
| Op::GetAttr<nnvm::TIsBackward>("TIsBackward"); |
  // result shape vector, indexed by entry id
| AttrVector rshape; |
| // vector holding information which operators |
| // finished attribute inference |
| std::vector<bool> inference_finished(idx.num_nodes(), false); |
| // dispatch mode vector |
| DispatchModeVector dispatch_modes; |
| if (ret.attrs.count(attr_name) != 0) { |
| rshape = ret.MoveCopyAttr<AttrVector>(attr_name); |
| } else { |
| rshape.resize(idx.num_node_entries(), empty_val); |
| } |
| |
| if (ret.attrs.count(input_name) != 0) { |
| const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name); |
| CHECK_LE(shape_args.size(), idx.input_nodes().size()) |
| << "More provided " << attr_name << "s than number of arguments."; |
| for (size_t i = 0; i < shape_args.size(); ++i) { |
| rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; |
| } |
| } |
| |
| // get the shape hints |
| std::string shape_hints_key = std::string(attr_name) + "_hints"; |
| if (ret.attrs.count(shape_hints_key)) { |
| nnvm::NodeEntryMap<AttrType> shape_hints = |
| ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key); |
| for (const auto& kv : shape_hints) { |
| nnvm::NodeEntry e = kv.first; |
| if (idx.exist(e.node.get())) { |
| rshape[idx.entry_id(kv.first)] = kv.second; |
| } |
| } |
| } |
| |
| std::string shape_attr_key; |
| if (ret.attrs.count(attr_key_name) != 0) { |
| shape_attr_key = ret.GetAttr<std::string>(attr_key_name); |
| // erase the provided arguments |
| ret.attrs.erase(attr_key_name); |
| } |
| |
| // limit inference to part of the graph |
| uint32_t node_start = 0, node_end = idx.num_nodes(); |
| if (ret.attrs.count("node_range")) { |
| const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range"); |
| node_start = range.first; |
| node_end = range.second; |
| CHECK_GE(node_start, 0); |
| CHECK_LE(node_end, idx.num_nodes()); |
| ret.attrs.erase("node_range"); |
| } |
| uint32_t entry_start = 0, entry_end = idx.num_node_entries(); |
| if (ret.attrs.count("entry_range")) { |
| const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range"); |
| entry_start = range.first; |
| entry_end = range.second; |
| CHECK_GE(entry_start, 0); |
| CHECK_LE(entry_end, idx.num_node_entries()); |
| ret.attrs.erase("entry_range"); |
| } |
| // populate the node attribute vector |
| if (dispatch_mode_name != nullptr) { |
| if (ret.attrs.count(dispatch_mode_name) != 0) { |
| dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name); |
| } else { |
| LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph"; |
| } |
| } |
| |
| // Temp space for shape inference. |
| std::vector<AttrType> ishape, oshape; |
| // whether a shape is dynamic |
| std::vector<int> is_dynamic(rshape.size(), 0); |
| |
| // convert to numpy compatible shape to use operator's infer shape function |
| if (!Imperative::Get()->is_np_shape()) { |
| common::ConvertToNumpyShape(&rshape); |
| } |
| |
| // inference step function for nid |
| auto infer_step = [&](uint32_t nid, bool last_iter) { |
| if (inference_finished[nid]) return; |
| const auto& inode = idx[nid]; |
| const std::string name = inode.source->attrs.name; |
| const uint32_t num_inputs = inode.inputs.size(); |
| const uint32_t num_outputs = inode.source->num_outputs(); |
| |
| if (inode.source->is_variable()) { |
| // Variable node. No operator. Only one output entry. |
| CHECK(inode.source->op() == nullptr); |
| CHECK_EQ(num_outputs, 1U); |
| const uint32_t out_ent_id = idx.entry_id(nid, 0); |
| if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { |
| auto it = inode.source->attrs.dict.find(shape_attr_key); |
| if (it != inode.source->attrs.dict.end()) { |
| std::istringstream is(it->second); |
| CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; |
| if (!Imperative::Get()->is_np_shape()) { |
| common::ConvertToNumpyShape(&rshape[out_ent_id]); |
| } |
| } |
| } |
| if (!fis_none(rshape[out_ent_id])) { |
| inference_finished[nid] = true; |
| } |
| // assign a default value to node attribute |
| if (dispatch_mode_name != nullptr) { |
| op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val); |
| } |
| } else if (is_backward.get(inode.source->op(), false) && |
| inode.source->control_deps.size() && bwd_identity_assign) { |
| CHECK(dispatch_mode_name == nullptr) |
| << "Backward inference for node attributes is not available"; |
| CHECK_GE(inode.source->control_deps.size(), 1U) |
| << "BackwardOp need to have control_deps to its forward op"; |
| nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0]; |
| CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; |
| |
| static auto& is_fusion_helper = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper"); |
| if (!is_fusion_helper.get(fwd_ptr->op(), false)) { |
| GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none); |
| } else { |
| GetAttrFromFusedNode<exec::FAccessSubgraphShape>(nid, idx, &rshape, |
| &inference_finished, |
| fis_none, |
| "FAccessSubgraphShape"); |
| } |
| } else { |
| DispatchMode* dispatch_mode = nullptr; |
| // Forward operator inference. |
| ishape.resize(num_inputs, empty_val); |
| bool is_input_dynamic_shape = false; |
| for (uint32_t i = 0; i < ishape.size(); ++i) { |
| ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; |
| if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) { |
| is_input_dynamic_shape = true; |
| } |
| } |
| oshape.resize(num_outputs, empty_val); |
| for (uint32_t i = 0; i < oshape.size(); ++i) { |
| oshape[i] = rshape[idx.entry_id(nid, i)]; |
| } |
| if (dispatch_mode_name != nullptr) { |
| dispatch_mode = &dispatch_modes[nid]; |
| } |
| auto finfer = finfer_shape.get(inode.source->op(), fdefault); |
| if (finfer == nullptr || is_input_dynamic_shape) { |
| for (uint32_t i = 0; i < oshape.size(); ++i) { |
| if (!mxnet::ndim_is_known(oshape[i].ndim())) { |
| is_dynamic[idx.entry_id(nid, i)] = 1; |
| } |
| } |
| inference_finished[nid] = true; |
| } else { |
| // Call inference function of the operator. |
| try { |
| static auto& is_fusion = Op::GetAttr<exec::TIsFusion>("TIsFusion"); |
| if (is_fusion.get(inode.source->op(), false)) { |
| ProvideAttrToFusion<exec::FProvideSubgraphShape>(nid, idx, rshape, |
| "FProvideSubgraphShape"); |
| } |
| ApplyOpInferAttr(ret, finfer, inode.source->attrs, |
| nid, &ishape, &oshape, dispatch_mode); |
| bool finished = true; |
| for (const auto& attr : ishape) { |
| if (fis_none(attr)) finished = false; |
| } |
| for (const auto& attr : oshape) { |
| if (fis_none(attr)) finished = false; |
| } |
| inference_finished[nid] = finished; |
| } catch (const std::exception& e) { |
| throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); |
| } |
| } |
| // Save to the result map. |
| for (uint32_t i = 0; i < num_inputs; ++i) { |
| rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; |
| } |
| for (uint32_t i = 0; i < num_outputs; ++i) { |
| rshape[idx.entry_id(nid, i)] = oshape[i]; |
| } |
| } |
| }; |
| |
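  // Alternate forward and backward sweeps over the node range until the
  // number of unknown shape elements stops decreasing; if some operator still
  // reports unfinished inference at that point, run one last sweep.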
| size_t last_num_unknown; |
  size_t num_unknown = static_cast<size_t>(-1);  // effectively infinity (SIZE_MAX)
| bool last_iter = false; |
| bool do_next_iteration = true; |
| |
| int i = 0; |
| do { |
| if (i % 2 == 0) { |
| // forward inference |
| for (uint32_t nid = node_start; nid < node_end; ++nid) { |
| infer_step(nid, last_iter); |
| } |
| } else { |
| // backward inference |
| for (uint32_t i = node_end; i != node_start; --i) { |
| infer_step(i - 1, last_iter); |
| } |
| } |
| last_num_unknown = num_unknown; |
| num_unknown = 0; |
| for (size_t j = entry_start; j < entry_end; ++j) { |
| if (fis_none(rshape[j])) { |
| num_unknown += fnum_unknown(rshape[j]); |
| } |
| } |
| if (dispatch_mode_name) { |
| for (size_t i = node_start; i < node_end; i++) { |
| if (dispatch_modes[i] == DispatchMode::kUndefined) { |
| ++num_unknown; |
| } |
| } |
| } |
| do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown; |
| if (!do_next_iteration && !last_iter) { |
| // Check if every op agrees that it should be |
| // the end of attribute inference. If not, |
| // perform one final step |
| for (const bool done : inference_finished) { |
| do_next_iteration = do_next_iteration || !done; |
| } |
| last_iter = true; |
| } |
| ++i; |
| } while (do_next_iteration); |
| // set the shapes |
| ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape)); |
  // set the dispatch modes
| if (dispatch_mode_name) { |
| ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes)); |
| } |
  // number of shape elements (and dispatch modes) that remain unknown
| ret.attrs[unknown_name] = std::make_shared<any>(num_unknown); |
| return ret; |
| } |
| |
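/*!\brief
 * Infer the shape attribute of every node entry in the graph, seeded with the
 * provided input shapes and the optional per-variable shape attribute key.
 */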
| nnvm::Graph InferShape(nnvm::Graph&& graph, |
| mxnet::ShapeVector&& shape_inputs, |
| const std::string& shape_attr_key) { |
| using dmlc::any; |
| if (shape_inputs.size() != 0) { |
| graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs)); |
| } |
| if (shape_attr_key.length() != 0) { |
| graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key); |
| } |
| return InferShapeAttr( |
| std::move(graph), mxnet::TShape(), |
| "FInferShape", "shape_inputs", "shape_attr_key", |
| "shape", "shape_num_unknown_nodes", |
| [](const mxnet::TShape& s) { return !mxnet::shape_is_known(s); }, |
| [](const mxnet::TShape& s) { |
| if (!mxnet::ndim_is_known(s)) { |
| return static_cast<size_t>(1); |
| } |
| size_t ret = 0; |
| for (const auto& val : s) { |
| if (!mxnet::dim_size_is_known(val)) { |
| ++ret; |
| } |
| } |
| return ret; |
| }, |
| nullptr, true, nullptr); |
| } |
| |
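/*!\brief
 * Infer the dtype attribute of every node entry in the graph, seeded with the
 * provided input dtypes and the optional per-variable dtype attribute key.
 */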
| nnvm::Graph InferType(nnvm::Graph&& graph, |
| nnvm::DTypeVector&& dtype_inputs, |
| const std::string& dtype_attr_key) { |
| using dmlc::any; |
| if (dtype_inputs.size() != 0) { |
| graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs)); |
| } |
| if (dtype_attr_key.length() != 0) { |
| graph.attrs["dtype_attr_key"] = std::make_shared<any>(dtype_attr_key); |
| } |
| return InferAttr<int, nnvm::FInferType, exec::FAccessSubgraphType, |
| exec::FProvideSubgraphType>( |
| std::move(graph), -1, |
| "FInferType", "FAccessSubgraphType", "FProvideSubgraphType", |
| "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes", |
| [](const int t) { return t == -1; }, |
| common::SameType, true, nullptr); |
| } |
| |
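/*!\brief
 * Infer the storage type of every node entry together with a dispatch mode
 * for every node. Unlike shape and dtype inference, backward storage types
 * are not required to match their forward counterparts.
 */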
| nnvm::Graph InferStorageType(nnvm::Graph&& graph, |
| StorageTypeVector&& storage_type_inputs, |
| const std::string& storage_type_attr_key) { |
| using dmlc::any; |
| if (storage_type_inputs.size() != 0) { |
| graph.attrs["storage_type_inputs"] = std::make_shared<any>(std::move(storage_type_inputs)); |
| } |
| if (storage_type_attr_key.length() != 0) { |
| graph.attrs["storage_type_attr_key"] = std::make_shared<any>(storage_type_attr_key); |
| } |
| // initialize unknown values for dispatch modes |
| if (graph.attrs.count("dispatch_mode") == 0) { |
| DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined); |
| graph.attrs["dispatch_mode"] = std::make_shared<any>(std::move(dispatch_modes)); |
| } |
| // initialize the dev_mask vector from the context vector |
| if (graph.attrs.count("dev_mask") == 0) { |
| CHECK_GT(graph.attrs.count("context"), 0); |
| DevMaskVector dev_masks(graph.indexed_graph().num_nodes()); |
| const ContextVector& vctx = graph.GetAttr<ContextVector>("context"); |
| for (size_t i = 0; i < vctx.size(); i++) dev_masks[i] = vctx[i].dev_mask(); |
| graph.attrs["dev_mask"] = std::make_shared<any>(std::move(dev_masks)); |
| } |
| |
  // for storage type, the backward attr is not necessarily the same as its forward correspondence
| nnvm::Graph ret = InferAttr<int, FInferStorageType, exec::FAccessSubgraphStorageType, |
| exec::FProvideSubgraphStorageType>( |
| std::move(graph), -1, |
| "FInferStorageType", "FAccessSubgraphStorageType", "FProvideSubgraphStorageType", |
| "storage_type_inputs", "storage_type_attr_key", "storage_type", |
| "storage_type_num_unknown_nodes", |
| [](const int t) { return t == -1; }, |
| common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable); |
| |
| // log the storage types and dispatch modes of the graph |
| static bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false); |
| if (log_verbose) { |
| common::LogInferStorage(ret); |
| } |
| return ret; |
| } |
| |
| } // namespace exec |
| } // namespace mxnet |