/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef MXNET_OPERATOR_SUBGRAPH_COMMON_H_
#define MXNET_OPERATOR_SUBGRAPH_COMMON_H_

#include <string>
#include <set>
#include <vector>

#include "../elemwise_op_common.h"
#include "../../imperative/exec_pass.h"

namespace mxnet {
namespace op {

enum SelectStatus { kFail = 0, kStart, kSuccess };
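
// Default attribute functions for subgraph operators: the operator's interface
// (number and names of inputs/outputs) is derived from the first subgraph
// symbol attached to the node through NodeAttrs::subgraphs.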
inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
  const nnvm::Symbol& sym = *attrs.subgraphs[0];
  return sym.ListInputNames(nnvm::Symbol::kAll).size();
}

inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
  const nnvm::Symbol& sym = *attrs.subgraphs[0];
  return sym.ListOutputNames().size();
}

inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
  const nnvm::Symbol& sym = *attrs.subgraphs[0];
  return sym.ListInputNames(nnvm::Symbol::kAll);
}

inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
  const nnvm::Symbol& sym = *attrs.subgraphs[0];
  return sym.ListOutputNames();
}
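
// Infers shapes for a subgraph by building an nnvm::Graph from the subgraph
// symbol, seeding the graph's shape attribute with the known input/output
// shapes, running exec::InferShape, and copying the inferred shapes back.
// Returns true only when no shape in the subgraph remains unknown.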
inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol& subgraph_sym,
                                         mxnet::ShapeVector* in_shapes,
                                         mxnet::ShapeVector* out_shapes) {
  using namespace exec;
  nnvm::Graph g;
  g.outputs = subgraph_sym.outputs;
  const auto& idx_g = g.indexed_graph();
  CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
  CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
  // Put the input and output shapes to the shape vector.
  mxnet::ShapeVector shapes(idx_g.num_node_entries());
  const auto& input_nids = idx_g.input_nodes();
  CHECK_EQ(input_nids.size(), in_shapes->size());
  for (size_t i = 0; i < in_shapes->size(); i++) {
    auto eid = idx_g.entry_id(input_nids[i], 0);
    shapes[eid] = in_shapes->at(i);
  }
  CHECK_EQ(g.outputs.size(), out_shapes->size());
  for (size_t i = 0; i < out_shapes->size(); i++) {
    auto eid = idx_g.entry_id(g.outputs[i]);
    shapes[eid] = out_shapes->at(i);
  }
  // Infer shape of the graph.
  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
  g = exec::InferShape(std::move(g));
  // Copy the inferred shape back to the input shapes and the output shapes.
  shapes = g.GetAttr<mxnet::ShapeVector>("shape");
  // assign to in_shapes
  for (size_t i = 0; i < in_shapes->size(); ++i) {
    const auto eid = idx_g.entry_id(input_nids[i], 0);
    SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
  }
  // assign to out_shapes
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    const auto eid = idx_g.entry_id(g.outputs[i]);
    SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
  }
  // Check if we have inferred the shapes correctly.
  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}

inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
                                   mxnet::ShapeVector* in_shapes,
                                   mxnet::ShapeVector* out_shapes) {
  return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
}
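
// Infers dtypes the same way: seed the graph's dtype attribute with the known
// input/output types (-1 marks unknown), run exec::InferType, and write the
// results back through TYPE_ASSIGN_CHECK.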
inline bool DefaultSubgraphOpTypeHelper(const nnvm::Symbol& subgraph_sym,
                                        std::vector<int>* in_types,
                                        std::vector<int>* out_types) {
  nnvm::Graph g;
  g.outputs = subgraph_sym.outputs;
  const auto& idx_g = g.indexed_graph();
  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
  CHECK_EQ(idx_g.outputs().size(), out_types->size());
  // Put the input and output data types to the dtype vector.
  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
  const auto& input_nids = idx_g.input_nodes();
  CHECK_EQ(input_nids.size(), in_types->size());
  for (size_t i = 0; i < in_types->size(); i++) {
    auto eid = idx_g.entry_id(input_nids[i], 0);
    types[eid] = in_types->at(i);
  }
  CHECK_EQ(g.outputs.size(), out_types->size());
  for (size_t i = 0; i < out_types->size(); i++) {
    auto eid = idx_g.entry_id(g.outputs[i]);
    types[eid] = out_types->at(i);
  }
  // Infer data type of the graph.
  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
  g = exec::InferType(std::move(g));
  types = g.GetAttr<nnvm::DTypeVector>("dtype");
  // assign to in_types
  for (size_t i = 0; i < in_types->size(); ++i) {
    const auto eid = idx_g.entry_id(input_nids[i], 0);
    TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
  }
  // assign to out_types
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    const auto eid = idx_g.entry_id(g.outputs[i]);
    TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
  }
  // Check if we have inferred the dtypes correctly.
  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
}

inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int>* in_types,
                                  std::vector<int>* out_types) {
  return DefaultSubgraphOpTypeHelper(*attrs.subgraphs[0], in_types, out_types);
}
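
// Infers storage types the same way, using exec::InferStorageType. All node
// entries are seeded with kUndefinedStorage except the known inputs/outputs,
// and the dispatch mode is forced to DispatchMode::kFComputeEx so the operator
// is dispatched through the NDArray-based (FComputeEx) interface.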
inline bool DefaultSubgraphOpStorageTypeHelper(const nnvm::Symbol& subgraph_sym,
                                               const int dev_mask,
                                               DispatchMode* dispatch_mode,
                                               std::vector<int>* in_stypes,
                                               std::vector<int>* out_stypes) {
  nnvm::Graph g;
  g.outputs = subgraph_sym.outputs;
  const auto& idx_g = g.indexed_graph();
  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
  // Put the input and output storages to the storage vector.
  StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
  const auto& input_nids = idx_g.input_nodes();
  CHECK_EQ(input_nids.size(), in_stypes->size());
  for (size_t i = 0; i < in_stypes->size(); i++) {
    auto eid = idx_g.entry_id(input_nids[i], 0);
    stypes[eid] = in_stypes->at(i);
  }
  CHECK_EQ(g.outputs.size(), out_stypes->size());
  for (size_t i = 0; i < out_stypes->size(); i++) {
    auto eid = idx_g.entry_id(g.outputs[i]);
    stypes[eid] = out_stypes->at(i);
  }
  // Infer storage type of the graph.
  bool dev_match =
      g.attrs.count("dev_mask") && g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
  if (!dev_match) {
    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
  }
  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
  g = exec::InferStorageType(std::move(g));
  stypes = g.GetAttr<StorageTypeVector>("storage_type");
  // assign to in_stypes
  for (size_t i = 0; i < in_stypes->size(); ++i) {
    const auto eid = idx_g.entry_id(input_nids[i], 0);
    STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes[eid]);
  }
  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
  // assign to out_stypes
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    const auto eid = idx_g.entry_id(g.outputs[i]);
    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
  }
  // Check if we have inferred the storages correctly.
  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
}

inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
                                         const int dev_mask,
                                         DispatchMode* dispatch_mode,
                                         std::vector<int>* in_stypes,
                                         std::vector<int>* out_stypes) {
  return DefaultSubgraphOpStorageTypeHelper(
      *attrs.subgraphs[0], dev_mask, dispatch_mode, in_stypes, out_stypes);
}
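
// Subgraph operators run the enclosed graph themselves, so they report
// ExecType::kSubgraphExec to the executor instead of behaving like a regular kernel.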
inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
}
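
// The mutable inputs of a subgraph operator are its auxiliary states. The helper
// merges the read-only and auxiliary input name lists (which together form the
// kAll list, in order) and records the positions that belong to auxiliary states.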
inline std::vector<uint32_t> DefaultSubgraphOpMutableInputsHelper(
    const nnvm::Symbol& subgraph_sym) {
  const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
  const std::vector<std::string> immutable_input_names =
      subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
  const std::vector<std::string> mutable_input_names =
      subgraph_sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
  CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
  std::vector<uint32_t> ret;
  size_t i1 = 0, i2 = 0;
  for (size_t i = 0; i < input_names.size(); ++i) {
    if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
      ++i1;
    } else {
      CHECK(i2 < mutable_input_names.size());
      CHECK_EQ(input_names[i], mutable_input_names[i2]);
      ++i2;
      ret.push_back(i);
    }
  }
  return ret;
}

inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
  return DefaultSubgraphOpMutableInputsHelper(*attrs.subgraphs[0]);
}
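
// A subgraph operator needs the union of the resources requested by the nodes
// it contains, so the helper walks the subgraph with DFSVisit and collects the
// distinct ResourceRequest types.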
inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequestHelper(
    const nnvm::Symbol& subgraph_sym) {
  static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
  std::set<ResourceRequest::Type> resource_types;
  DFSVisit(subgraph_sym.outputs, [&](const nnvm::ObjectPtr& node) {
    if (!node->is_variable() && fresource.count(node->op())) {
      for (ResourceRequest& r : fresource[node->op()](node->attrs)) {
        resource_types.insert(r.type);
      }
    }
  });
  return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
}

inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
  return DefaultSubgraphOpResourceRequestHelper(*attrs.subgraphs[0]);
}
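
// Example (sketch): how a backend typically wires these defaults into an operator
// registration. The operator name "_my_subgraph_op" is hypothetical; real backends
// additionally register a state-creation function and an FStatefulComputeEx kernel.
//
//   NNVM_REGISTER_OP(_my_subgraph_op)
//       .set_num_inputs(DefaultSubgraphOpNumInputs)
//       .set_num_outputs(DefaultSubgraphOpNumOutputs)
//       .set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
//       .set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
//       .set_attr<mxnet::FInferShape>("FInferShape", DefaultSubgraphOpShape)
//       .set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
//       .set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
//       .set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
//       .set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
//       .set_attr<FResourceRequest>("FResourceRequest", DefaultSubgraphOpResourceRequest);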
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_SUBGRAPH_COMMON_H_