/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 * \file pointwise_fusion_pass.cc
 * \brief Pass applying pointwise fusion.
 * \author Clement Fuji Tsang
 */
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <algorithm>
#include <queue>
#include <chrono>
#include "./simple_partition_pass.h"
#include "../operator/fusion/fused_op-inl.h"
#include "../operator/fusion/fused_op.h"
#include "../operator/operator_common.h"
namespace mxnet {
namespace exec {
void WarnFusionNotSupported() {
  static bool issued_warning = false;
  if (!issued_warning) {
    issued_warning = true;
#if defined(_WIN32)
    LOG(WARNING) << "Omitting dynamic fused op creation - not enabled on Windows. "
                 << "Unset env var MXNET_USE_FUSION=1 to quiet this message.";
#else
    LOG(WARNING) << "Omitting dynamic fused op creation - needs MXNet lib built with "
                 << "USE_CUDA=1 and ENABLE_CUDA_RTC=1. Unset env var MXNET_USE_FUSION=1 "
                 << "to quiet this message.";
#endif  // defined(_WIN32)
  }
}
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
namespace {
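/* \brief Check whether a node can appear anywhere inside a fused pointwise
 *        subgraph: its op must be listed in ops_desc or variable_io_ops, or be a
 *        LeakyReLU / _backward_LeakyReLU with a supported act_type. Ops from
 *        slice_ops are handled separately by IsInputsOnlyCompatible.
 */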
bool IsFusionCompatible(const nnvm::Node* n) {
  using namespace mxnet::fusion;
  if (n->op() == nullptr)
    return false;
  const std::string& op_name = n->op()->name;
  if (ops_desc.count(op_name))
    return true;
  if (slice_ops.count(op_name))
    return false;
  if (std::find(variable_io_ops.begin(),
                variable_io_ops.end(),
                op_name) !=
      variable_io_ops.end())
    return true;
  if (op_name == "LeakyReLU") {
    std::string act_type = n->attrs.dict.at("act_type");
    if (LeakyReLU_ops.count(act_type))
      return true;
    else
      return false;
  }
  if (op_name == "_backward_LeakyReLU") {
    std::string act_type = n->attrs.dict.at("act_type");
    if (LeakyReLU_bwd_ops.count(act_type))
      return true;
    else
      return false;
  }
  return false;
}
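/* \brief Check whether a node may be fused only as an input of a subgraph:
 *        this is the case for ops from slice_ops, except slice with a
 *        non-default step attribute.
 */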
bool IsInputsOnlyCompatible(const nnvm::Node* n) {
  using namespace mxnet::fusion;
  if (n->op() == nullptr)
    return false;
  const std::string& op_name = n->op()->name;
  if (slice_ops.count(op_name)) {
    if (op_name == "slice") {
      // slice with non-default step attribute is not supported
      // currently
      if (n->attrs.dict.count("step") &&
          !(n->attrs.dict.at("step") == "()" ||
            n->attrs.dict.at("step") == "[]")) {
        return false;
      }
    }
    return true;
  }
  return false;
}
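/* \brief Turn subgraph_node into a _FusedOp node: attach the subgraph as its
 *        symbol, record num_inputs / num_outputs in its attribute dict, and run
 *        the op's attr_parser.
 */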
void CreateSubgraphNode(const nnvm::Graph& subgraph,
                        size_t inputs_size,
                        nnvm::Node* subgraph_node) {
  static const Op* fused_op_ptr = Op::Get("_FusedOp");
  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
  subgraph_node->attrs.op = fused_op_ptr;
  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
}
struct EntryInfo {
  int source_node;
  int index;
};
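/* \brief Insert an entry into a vector used as a set and return the entry's
 *        index, whether it was newly added or already present.
 */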
inline int SetInsert(const EntryInfo& new_elem,
                     std::vector<EntryInfo>* elements) {
  for (size_t i = 0; i < elements->size(); ++i) {
    if ((new_elem.source_node == elements->at(i).source_node) &&
        (new_elem.index == elements->at(i).index)) {
      return i;
    }
  }
  elements->emplace_back(new_elem);
  return elements->size() - 1;
}
} // namespace
/* \brief Create (if necessary) a copy of the graph, replacing subgraphs with
 *        FusedOps. If there are no subgraphs to be replaced, the
 *        original graph is returned.
 * \param g original graph.
 * \param subgraph_assignment assignment of nodes in g's IndexedGraph to
 *                            subgraphs. Values from -1 to num_subgraphs - 1
 *                            are allowed, -1 means that the node is not in a
 *                            subgraph.
 * \param num_subgraphs number of subgraphs.
 * \param create_subgraph_node function used to prepare the subgraph node.
 */
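/* For illustration only (hypothetical values, not produced in this file): with
 * subgraph_assignment = {-1, 0, 0, -1, 1} and num_subgraphs = 2, nodes 1 and 2
 * are replaced by one FusedOp, node 4 by a second one, and nodes 0 and 3 are
 * copied into the returned graph unchanged.
 */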
template<typename FCreateNode>
Graph CopyAndReplaceSubgraphs(const Graph& g,
                              const std::vector<int>& subgraph_assignment,
                              const int num_subgraphs,
                              FCreateNode create_subgraph_node) {
  if (num_subgraphs == 0) {
    return g;
  }
  Graph ret;
  const auto& idx = g.indexed_graph();
  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
    "Every node in the graph needs to be included in subgraph assignment.";
  std::vector<nnvm::ObjectPtr> new_nodes;
  new_nodes.reserve(idx.num_nodes());
  struct SubgraphInfo {
    nnvm::Graph graph;
    nnvm::ObjectPtr subgraph_node;
    std::vector<EntryInfo> outputs;
    std::vector<EntryInfo> inputs;
    std::vector<nnvm::ObjectPtr> input_nodes;
  };
  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
  for (auto& info : subgraphs) {
    info.subgraph_node = nnvm::Node::Create();
  }
  for (size_t i = 0; i < idx.num_nodes(); ++i) {
    // First copy the node, it will be used
    // either in the new graph or inside a
    // subgraph. Variables are not copied.
    if (idx[i].source->op() != nullptr) {
      new_nodes.emplace_back(nnvm::Node::Create());
      auto& node_copy = new_nodes.back();
      node_copy->attrs = idx[i].source->attrs;
      node_copy->info = idx[i].source->info;
    } else {
      new_nodes.emplace_back(idx[i].weak_ref.lock());
      continue;
    }
    auto& node_copy = new_nodes.back();
    const int subgraph_id = subgraph_assignment[i];
    if (subgraph_id != -1) {
      auto& info = subgraphs[subgraph_id];
      for (const auto& input : idx[i].inputs) {
        const int their_subgraph = subgraph_assignment[input.node_id];
        if (their_subgraph == subgraph_id) {
          node_copy->inputs.emplace_back(new_nodes[input.node_id],
                                         input.index,
                                         input.version);
        } else {
          int input_num;
          int output_num;
          if (their_subgraph == -1) {
            input_num = SetInsert({static_cast<int>(input.node_id),
                                   static_cast<int>(input.index)}, &(info.inputs));
          } else {
            auto& their_subgraph_info = subgraphs[their_subgraph];
            output_num = SetInsert({static_cast<int>(input.node_id),
                                    static_cast<int>(input.index)},
                                   &(their_subgraph_info.outputs));
            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
                                   output_num},
                                  &(info.inputs));
          }
          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
            info.input_nodes.emplace_back(nnvm::Node::Create());
            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
            if (their_subgraph == -1) {
              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
                                                      input.index,
                                                      input.version);
            } else {
              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
                                                      output_num,
                                                      input.version);
            }
          }
          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
        }
      }
    } else {
      for (const auto& input : idx[i].inputs) {
        const int subgraph_id = subgraph_assignment[input.node_id];
        if (subgraph_id == -1) {
          node_copy->inputs.emplace_back(new_nodes[input.node_id],
                                         input.index,
                                         input.version);
        } else {
          auto& info = subgraphs[subgraph_id];
          const int output_num = SetInsert({static_cast<int>(input.node_id),
                                            static_cast<int>(input.index)},
                                           &(info.outputs));
          node_copy->inputs.emplace_back(info.subgraph_node,
                                         output_num,
                                         input.version);
        }
      }
    }
    // Control deps
    for (const auto& dep : idx[i].control_deps) {
      if (subgraph_id == subgraph_assignment[dep]) {
        node_copy->control_deps.emplace_back(new_nodes[dep]);
      }
    }
  }
  ret.outputs.reserve(idx.outputs().size());
  for (const auto& output : idx.outputs()) {
    const int subgraph_id = subgraph_assignment[output.node_id];
    if (subgraph_id == -1) {
      ret.outputs.emplace_back(new_nodes[output.node_id],
                               output.index,
                               output.version);
    } else {
      const int output_num = SetInsert({static_cast<int>(output.node_id),
                                        static_cast<int>(output.index)},
                                       &(subgraphs[subgraph_id].outputs));
      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
                               output_num,
                               output.version);
    }
  }
  for (auto& info : subgraphs) {
    info.graph.outputs.reserve(info.outputs.size());
    for (const auto& entry_info : info.outputs) {
      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
                                      entry_info.index,
                                      0);
    }
    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
  }
  for (size_t i = 0; i < idx.num_nodes(); ++i) {
    // Add _FusedOpHelper nodes
    const int subgraph_id = subgraph_assignment[i];
    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
      const auto& dep = idx[i].control_deps[dep_num];
      const int their_subgraph_id = subgraph_assignment[dep];
      if (subgraph_id != -1 && their_subgraph_id == -1) {
        // Not in any subgraph, use FusedOpOutHelper
        auto& info = subgraphs[subgraph_id];
        size_t node_id = info.subgraph_node->control_deps.size();
        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
        auto helper_node = op::MakeNode("_FusedOpOutHelper",
                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
                                        nullptr,
                                        nullptr,
                                        nullptr);
        helper_node->attrs.parsed =
          FusedOpHelperParamPtr(new FusedOpHelperParam(
              nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
              node_id));
        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
                                          std::move(helper_node));
      } else if (their_subgraph_id != subgraph_id &&
                 their_subgraph_id != -1) {
        auto& info = subgraphs[their_subgraph_id];
        const auto& subgraph_idx = info.graph.indexed_graph();
        uint32_t node_id = subgraph_idx.node_id(new_nodes[dep].get());
        auto helper_node = op::MakeNode("_FusedOpHelper",
                                        info.subgraph_node->attrs.name + "_"
                                        + idx[i].source->attrs.name + "_helper",
                                        nullptr,
                                        nullptr,
                                        nullptr);
        helper_node->attrs.parsed =
          FusedOpHelperParamPtr(new FusedOpHelperParam(
              nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
              node_id));
        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
                                          std::move(helper_node));
      }
    }
  }
  for (auto& info : subgraphs) {
    const auto& idx = info.graph.indexed_graph();
    const auto& input_nodes = idx.input_nodes();
    std::vector<nnvm::NodeEntry> subgraph_inputs;
    subgraph_inputs.reserve(info.subgraph_node->inputs.size());
    for (const int input : input_nodes) {
      for (size_t i = 0; i < info.input_nodes.size(); ++i) {
        const auto& input_ptr = info.input_nodes[i].get();
        if (input_ptr == idx[input].source) {
          subgraph_inputs.emplace_back(info.subgraph_node->inputs[i]);
        }
      }
    }
    info.subgraph_node->inputs.swap(subgraph_inputs);
    std::string name;
    for (size_t i = 0; i < idx.num_nodes(); ++i) {
      if (idx[i].source->op() != nullptr) {
        name += idx[i].source->op()->name + "_";
      }
    }
    info.subgraph_node->attrs.name = name;
  }
  return ret;
}
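/* \brief Entry point of the pointwise fusion pass: compute compatible subsets of
 *        nodes via GetCompatibleSubsets and replace each subset with a _FusedOp
 *        node. Timing is logged when MXNET_RTC_VERBOSE is set.
 */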
Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) {
  auto start = std::chrono::steady_clock::now();
  std::vector<int> subset_assignment;
  int num_subsets;
  std::tie(subset_assignment, num_subsets) = GetCompatibleSubsets(g, num_forward_outputs,  // NOLINT(*)
                                                                  IsFusionCompatible,
                                                                  IsInputsOnlyCompatible);
  Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets,
                                      CreateSubgraphNode);
  auto end = std::chrono::steady_clock::now();
  if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
    auto diff = end - start;
    LOG(INFO) << "Pointwise fusion graph pass took: "
              << std::chrono::duration<double, std::milli>(diff).count()
              << "ms.";
  }
  return ret;
}
#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
} // namespace exec
} // namespace mxnet