// blob: 9b26d366f2abc3c155385d97f142a1b7fb39ecf0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file low_precision_pass.cc
* \brief Return new graph with amp_cast and amp_multicast operators added wherever required
*/
#include <nnvm/node.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <mxnet/base.h>
#include <algorithm>
#include <functional>
#include "operator/operator_common.h"
#include "common/utils.h"
namespace mxnet {
using nnvm::Graph;
using nnvm::Node;
using nnvm::NodeEntry;
using nnvm::ObjectPtr;
/*! \brief Tells whether the given operator is a casting operator (either amp_cast or Cast).
 * A null operator is never a casting operator.
 */
bool IsCastOp(const nnvm::Op* const op) {
  if (op == nullptr) {
    return false;
  }
  return op == Op::Get("amp_cast") || op == Op::Get("Cast");
}
/*!
* \brief Before the model conversion, node entries of the original graph are mapped to the
* equivalent node entries in the new graph that will be then converted to a mixed precision graph.
* This class wraps a mapped NodeEntry from the new graph, providing a transparent interface for
* acquiring versions of the wrapped entry with a specific dtype, adding a casting nodes to the
* graph when needed (one for each unique dtype that was requested).
*/
class MappedNodeEntry {
public:
MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
: entry(std::move(node_entry)), original_dtype(original_dtype) {
dtype = original_dtype;
}
/*!
* \brief Converts the dtype of this NodeEntry. This should be called after a node has been
* converted and dtypes of its outputs may have changed
*/
void UpdateDTypeAfterConversion(const int new_dtype) {
CHECK(dtype == original_dtype || dtype == new_dtype); // dtype should be changed only once
CHECK(entry.node->op());
CHECK_NE(new_dtype, -1);
dtype = new_dtype;
}
/*!
* \brief If dtype of this NodeEntry was not changed, returns the mapped entry. Otherwise returns
* a NodeEntry to the node which casts to the original dtype of this NodeEntry.
*/
const NodeEntry& AsOriginal() {
return AsType(original_dtype);
}
/*!
* \brief If dtype of this NodeEntry matches the specified dtype, returns the mapped entry.
* Otherwise returns a NodeEntry to the node which casts to that type.
*/
const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = true) {
if (dtype == target_dtype || target_dtype == -1) {
return entry;
}
NodeEntry& cast_entry = casts[target_dtype];
if (cast_entry.node == nullptr) {
CHECK(can_add_cast);
cast_entry = Cast(target_dtype);
CHECK(cast_entry.node);
}
return cast_entry;
}
/*! \brief Returns whether this entry has the specified dtype or an existing cast to that dtype */
bool HasDTypeEntry(const int target_dtype) const {
CHECK_NE(target_dtype, -1);
return dtype == target_dtype || casts.count(target_dtype) > 0;
}
/*!
* \brief Returns whether this entry can be cast to a specific dtype. This should be called on
* input entires of a node before its conversion.
*/
bool CanBeCastTo(const int target_dtype) {
CHECK_NE(target_dtype, -1);
static const auto& amp_cast_op = Op::Get("amp_cast");
static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType")[amp_cast_op];
nnvm::NodeAttrs dummy_atts;
dummy_atts.dict["dtype"] = mxnet::op::type_string(target_dtype);
amp_cast_op->attr_parser(&dummy_atts);
std::vector<int> in_types = {dtype};
std::vector<int> out_types = {-1};
return infertype(dummy_atts, &in_types, &out_types);
}
/*! \brief Returns whether this NodeEntry (of a parameter) can be cast offline */
bool CanBeCastOfflineTo(const int target_dtype) const {
CHECK(entry.node->is_variable());
CHECK_NE(target_dtype, -1);
return casts.count(target_dtype) > 0;
}
private:
ObjectPtr CreateCastNode(const std::string& op_name, const std::string& node_name) {
CHECK_GT(op_name.size(), 0);
ObjectPtr node = Node::Create();
node->attrs.name = node_name;
node->attrs.op = Op::Get(op_name);
node->inputs.emplace_back(entry);
return node;
}
NodeEntry Cast(const int target_dtype) {
CHECK(CanBeCastTo(target_dtype));
const std::string dt_name = mxnet::op::type_string(target_dtype);
const std::string suffix = "_" + std::to_string(entry.index);
const std::string cast_node_name = entry.node->attrs.name + suffix + "_amp_cast_" + dt_name;
ObjectPtr cast_node = CreateCastNode("amp_cast", cast_node_name);
cast_node->attrs.dict["dtype"] = dt_name;
cast_node->op()->attr_parser(&(cast_node->attrs));
return NodeEntry{std::move(cast_node), 0, 0};
}
public:
const NodeEntry entry;
const int original_dtype; // original dtype of the entry
private:
int dtype; // current dtype of the entry
std::unordered_map<int, NodeEntry> casts;
};
// maps node entries of the original graph to their wrapped counterparts in the new graph
using EntryMap_t = nnvm::NodeEntryMap<MappedNodeEntry>;
// maps nodes of the original graph to their copies in the new graph
using NodeMap_t = std::unordered_map<Node*, ObjectPtr>;
// set of unique node entries
using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, nnvm::NodeEntryEqual>;
// for a node of the original graph: its output entries which are used in the graph
using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
// for a parameter node of the original graph: nodes consuming it, each with the consumed entry
using DstNodes_t = std::unordered_map<Node*, std::unordered_map<Node*, NodeEntry>>;
/*! \brief Makes sure the node in the new graph will work with the same precision as in the original
* graph */
static void KeepOriginalNode(const ObjectPtr& old_node,
const NodeMap_t& node_map,
EntryMap_t* const entry_map) {
const ObjectPtr& new_node = node_map.at(old_node.get());
for (const auto& old_ne : old_node->inputs) {
new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
}
}
/*!
 * \brief Tries to convert the node to low precision.
 *
 * Inputs which already have a target_dtype version are requested in target_dtype (when there is
 * none, the first input is), and the type inference function of the operator decides the dtypes
 * of the remaining inputs and of the outputs. The conversion is abandoned, leaving the new graph
 * untouched, when: the node has no inputs, the operator has no type inference function or the
 * inference fails, an input mutated by the operator would have to be in target_dtype, or a
 * required cast of an input is not possible.
 *
 * On success the inputs of the new node are connected (with casts where needed) and the dtypes of
 * the mapped output entries of the node are updated.
 *
 * \param target_dtype low precision dtype
 * \param old_node node of the original graph to convert
 * \param node_map mapping from nodes of the original graph to nodes of the new graph
 * \param nodes_entries used output entries of the nodes of the original graph
 * \param entry_map mapping from entries of the original graph to the wrapped entries
 * \return whether the node has been successfully converted
 */
static bool TryLowPrecision(const int target_dtype,
                            const ObjectPtr& old_node,
                            const NodeMap_t& node_map,
                            const NodesEntries_t& nodes_entries,
                            EntryMap_t* const entry_map) {
  static const auto& infertype      = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
  static const auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");

  const auto& old_inputs  = old_node->inputs;
  const size_t num_inputs = old_inputs.size();
  if (num_inputs == 0) {
    // There is no input which could be requested in low precision. Bail out instead of writing
    // to in_types[0] of an empty vector below.
    return false;
  }

  std::vector<int> in_types(num_inputs, -1);
  bool has_lp_input = false;
  for (size_t i = 0; i < num_inputs; ++i) {
    if (entry_map->at(old_inputs[i]).HasDTypeEntry(target_dtype)) {
      in_types[i]  = target_dtype;
      has_lp_input = true;
    }
  }
  if (!has_lp_input) {
    // when inputs are not already in low precision, assume the first input should be in low
    // precision in order to convert this op
    in_types[0] = target_dtype;
  }

  // infer types of other inputs (and of the outputs)
  std::vector<int> out_types(old_node->num_outputs(), -1);
  if (infertype.count(old_node->op()) == 0 ||
      !infertype[old_node->op()](old_node->attrs, &in_types, &out_types)) {
    return false;
  }

  // inputs which are mutated by the op must not be requested in low precision
  if (fmutate_inputs.count(old_node->op()) != 0) {
    const std::vector<uint32_t> mutable_inputs = fmutate_inputs[old_node->op()](old_node->attrs);
    for (size_t i = 0; i < num_inputs; ++i) {
      if (in_types[i] == target_dtype &&
          std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != mutable_inputs.end()) {
        return false;
      }
    }
  }

  // verify all required casts up front, so that nothing is connected when one of them is
  // impossible
  for (size_t i = 0; i < num_inputs; ++i) {
    MappedNodeEntry& mapped_ne = entry_map->at(old_inputs[i]);
    // if this tensor needs a cast, check whether MappedNodeEntry can actually cast it
    if (in_types[i] != -1 && !mapped_ne.HasDTypeEntry(in_types[i]) &&
        !mapped_ne.CanBeCastTo(in_types[i])) {
      return false;
    }
  }

  const ObjectPtr& new_node = node_map.at(old_node.get());
  for (size_t i = 0; i < num_inputs; ++i) {
    new_node->inputs.push_back(entry_map->at(old_inputs[i]).AsType(in_types[i]));
  }
  for (const NodeEntry& old_ne : nodes_entries.at(old_node.get())) {
    entry_map->at(old_ne).UpdateDTypeAfterConversion(out_types[old_ne.index]);
  }
  return true;
}
/*!
 * \brief Converts the node to low precision only if this requires no new casts of its inputs:
 * at least one input must already be available in target_dtype, and every input dtype requested
 * by the type inference of the operator must already be available as well. In any other case
 * (also when the conversion itself fails) the node is kept unchanged.
 */
static void HandleWidestDtypeNode(const int target_dtype,
                                  const ObjectPtr& old_node,
                                  const NodeMap_t& node_map,
                                  const NodesEntries_t& nodes_entries,
                                  EntryMap_t* const entry_map) {
  static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
  const auto& old_inputs       = old_node->inputs;

  // gather info about current dtypes of inputs, setting only those available in low precision
  std::vector<int> in_types(old_inputs.size(), -1);
  bool has_lp_input = false;
  for (size_t i = 0; i < old_inputs.size(); ++i) {
    if (entry_map->at(old_inputs[i]).HasDTypeEntry(target_dtype)) {
      in_types[i]  = target_dtype;
      has_lp_input = true;
    }
  }

  // tells whether the op can get all the input dtypes it needs without adding any cast
  const auto inputs_fit_low_precision = [&]() {
    if (!has_lp_input) {
      return false;
    }
    // run infertype, to see what other input types this op needs with the current lp inputs
    std::vector<int> out_types(old_node->num_outputs(), -1);
    if (infertype.count(old_node->op()) == 0 ||
        !infertype[old_node->op()](old_node->attrs, &in_types, &out_types)) {
      return false;
    }
    // if we have to add casts to inputs, this op shouldn't run in low precision
    for (size_t i = 0; i < old_inputs.size(); ++i) {
      if (in_types[i] != -1 && !entry_map->at(old_inputs[i]).HasDTypeEntry(in_types[i])) {
        return false;
      }
    }
    return true;
  };

  if (inputs_fit_low_precision() &&
      TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
    return;
  }
  KeepOriginalNode(old_node, node_map, entry_map);
}
/*!
* \brief Tries to convert the node to low precision if some of its inputs already are converted.
* Otherwise keeps the node unchanged.
*/
void HandleDTypeNeutralNode(const int target_dtype,
const ObjectPtr& old_node,
const NodeMap_t& node_map,
const NodesEntries_t& nodes_entries,
EntryMap_t* const entry_map) {
const auto& is_lp = [&](const auto& old_ne) {
return entry_map->at(old_ne).HasDTypeEntry(target_dtype);
};
if (!std::any_of(old_node->inputs.begin(), old_node->inputs.end(), is_lp) ||
!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
KeepOriginalNode(old_node, node_map, entry_map);
}
}
/*!
 * \brief Decides which parameters can be cast offline and removes redundant cast nodes from the
 * graph.
 *
 * A parameter can be cast offline when, for each of its registered destination nodes, no input of
 * the converted destination node refers to the parameter node directly, and a cast of the
 * parameter to target_dtype is present. For such a parameter, all inputs of its destination nodes
 * which refer to that cast are redirected to the parameter itself, and the attribute
 * `offline_param_cast_attr` of the parameter node is set to the name of target_dtype.
 *
 * \param target_dtype low precision dtype
 * \param offline_param_cast_attr name of the attribute set on parameters to be cast offline
 * \param node_map mapping from nodes of the original graph to nodes of the new graph
 * \param old_param_dst_nodes for each parameter of the original graph: the nodes consuming it,
 * each with the consumed entry
 * \param entry_map mapping from entries of the original graph to the wrapped entries
 */
static void RemoveParamCasts(const int target_dtype,
                             const std::string& offline_param_cast_attr,
                             const NodeMap_t& node_map,
                             const DstNodes_t& old_param_dst_nodes,
                             EntryMap_t* entry_map) {
  for (const auto& [old_param, old_param_dsts] : old_param_dst_nodes) {
    const ObjectPtr& new_param = node_map.at(old_param);
    const auto& can_be_cast_offline = [&](const std::pair<Node*, NodeEntry>& old_param_dst) {
      const ObjectPtr& param_dst_node        = node_map.at(old_param_dst.first);
      const MappedNodeEntry& param_mapped_ne = entry_map->at(old_param_dst.second);
      // a destination node which uses the parameter directly needs it in the original dtype
      for (const NodeEntry& node_entry : param_dst_node->inputs) {
        if (node_entry.node == new_param) {
          return false;
        }
      }
      return param_mapped_ne.CanBeCastOfflineTo(target_dtype);
    };
    if (!std::all_of(old_param_dsts.begin(), old_param_dsts.end(), can_be_cast_offline)) {
      continue;
    }

    nnvm::NodeEntryEqual are_equal;
    for (const auto& [old_dst_node, old_ne] : old_param_dsts) {
      MappedNodeEntry& mapped_ne      = entry_map->at(old_ne);
      const NodeEntry& new_ne_to_skip = mapped_ne.AsType(target_dtype, false);
      const ObjectPtr& new_dst_node   = node_map.at(old_dst_node);
      bool skipped_amp_cast           = false;
      // A destination node may consume the same parameter through more than one input (only one
      // entry is registered per destination node), so every input referring to the cast is
      // redirected - stopping at the first one would leave a redundant cast in the graph.
      for (NodeEntry& new_ne : new_dst_node->inputs) {
        if (are_equal(new_ne, new_ne_to_skip)) {
          new_ne           = mapped_ne.entry;
          skipped_amp_cast = true;
        }
      }
      CHECK(skipped_amp_cast);
    }
    new_param->attrs.dict[offline_param_cast_attr] = mxnet::op::type_string(target_dtype);
  }
}
/*!
 * \brief Graph pass which builds a mixed precision version of the `src` graph.
 *
 * Reads the following graph attributes: "target_dtype", "cast_params_offline",
 * "offline_param_cast_attr", "input_names", "target_dtype_ops", "fp32_ops", "widest_dtype_ops"
 * and "dtype" (the dtype of every node entry, all of them must be known).
 * Nodes of `src` are copied to a new graph and their inputs are connected according to the lists
 * of operators, with casts added where needed. Outputs of the new graph are provided in their
 * original dtypes.
 *
 * \param src the graph to convert
 * \return a new graph which has only its outputs set
 */
Graph ReducePrecision(Graph&& src) {
const auto target_dtype = src.GetAttr<int>("target_dtype");
const auto cast_params_offline = src.GetAttr<int>("cast_params_offline");
const auto& offline_param_cast_attr = src.GetAttr<std::string>("offline_param_cast_attr");
const auto& input_names = src.GetAttr<std::unordered_set<std::string>>("input_names");
const auto& target_dtype_ops = src.GetAttr<std::unordered_set<std::string>>("target_dtype_ops");
const auto& fp32_ops = src.GetAttr<std::unordered_set<std::string>>("fp32_ops");
const auto& widest_dtype_ops = src.GetAttr<std::unordered_set<std::string>>("widest_dtype_ops");
auto src_dtypes = src.GetAttr<nnvm::DTypeVector>("dtype"); // copy, not reference
CHECK(target_dtype == mshadow::kFloat16 || target_dtype == mshadow::kBfloat16)
<< "Only float16 and bfloat16 target_dtype is supported yet," << target_dtype;
const nnvm::IndexedGraph& src_idx = src.indexed_graph();
// there must be exactly one known dtype for each node entry of the source graph
CHECK_EQ(src_dtypes.size(), src_idx.num_node_entries());
for (const int src_dtype : src_dtypes) {
CHECK_NE(src_dtype, -1) << "Infer type failed with full information about input types";
}
NodeMap_t node_map;
EntryMap_t entry_map;
NodesEntries_t nodes_entries;
DstNodes_t old_param_dst_nodes;
// Registers `old_ne` of the original graph: wraps its counterpart from the new graph in a
// MappedNodeEntry and records it in the bookkeeping maps above. The node producing `old_ne` has
// to be present in node_map already (node_map.at throws otherwise).
const auto& register_node_entry =
[&](const NodeEntry& old_ne, const ObjectPtr& old_dst_node, const ObjectPtr& new_dst_node) {
// new_dst_node is the node that should own `old_ne` equivalent as one of its input
const uint32_t entry_id = src_idx.entry_id(old_ne);
const int original_ne_dtype = src_dtypes[entry_id];
const ObjectPtr& old_src_node = old_ne.node;
const ObjectPtr& new_src_node = node_map.at(old_src_node.get());
const NodeEntry new_ne = NodeEntry(new_src_node, old_ne.index, old_ne.version);
// emplace does nothing when the entry has already been registered
entry_map.emplace(old_ne, MappedNodeEntry(new_ne, original_ne_dtype));
// remember which output entries of the source node are used in the graph
nodes_entries[old_src_node.get()].insert(old_ne);
// register which nodes use parameters (variables which are not listed in input_names)
if (new_dst_node && old_src_node->is_variable() &&
input_names.count(old_src_node->attrs.name) == 0) {
CHECK(old_dst_node);
old_param_dst_nodes[old_src_node.get()][old_dst_node.get()] = old_ne;
}
};
// gather information about node entries and build a new graph
// (relies on DFSVisit visiting the inputs of a node before the node itself)
DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
ObjectPtr new_node = Node::Create(*old_node);
// the copy starts without inputs, they are connected during the conversion
new_node->inputs.clear();
for (const NodeEntry& old_ne : old_node->inputs) {
register_node_entry(old_ne, old_node, new_node);
}
node_map.emplace(old_node.get(), std::move(new_node));
});
// graph outputs have no destination node, hence the null destination nodes
for (const NodeEntry& old_out_ne : src.outputs) {
register_node_entry(old_out_ne, nullptr, nullptr);
}
// convert the model
// Connects the inputs of the new-graph counterpart of `old_node`, choosing the way of handling
// the node by the kind of the node and by the operator lists
const auto convert_node_fn = [&](const ObjectPtr& old_node) {
// variables and casting nodes (amp_multicast, amp_cast, Cast) keep their inputs as they are
if (old_node->is_variable() || old_node->op() == Op::Get("amp_multicast") ||
IsCastOp(old_node->op())) {
const ObjectPtr& new_node = node_map.at(old_node.get());
for (const auto& old_ne : old_node->inputs) {
const ObjectPtr& new_in_node = node_map.at(old_ne.node.get());
new_node->inputs.emplace_back(new_in_node, old_ne.index, old_ne.version);
}
return;
}
auto opt_constraints =
common::flag_attr_accumulate<OptConstraint_int_t>(old_node->attrs, OPT_CONSTRAINT_ATTR);
// ops from fp32_ops and nodes with the DisableAMP constraint keep their original precision
if (fp32_ops.count(old_node->op()->name) > 0 ||
(opt_constraints & static_cast<OptConstraint_int_t>(OptConstraint::DisableAMP))) {
KeepOriginalNode(old_node, node_map, &entry_map);
} else if (target_dtype_ops.count(old_node->op()->name) > 0) {
if (!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, &entry_map)) {
LOG(WARNING) << "Low precision conversion failure. Node '" + old_node->attrs.name +
"' will not be converted.";
KeepOriginalNode(old_node, node_map, &entry_map);
}
} else if (widest_dtype_ops.count(old_node->op()->name) > 0) {
HandleWidestDtypeNode(target_dtype, old_node, node_map, nodes_entries, &entry_map);
} else {
HandleDTypeNeutralNode(target_dtype, old_node, node_map, nodes_entries, &entry_map);
}
};
// Because some nodes depend on casts present in the graph, the order of visited nodes will
// determine whether some nodes are converted or not. To avoid this, first we make a virtual
// conversion pass in order to have all the necessary casts already present (in the
// MappedNodeEntry instances) during the second (true) conversion pass
// virtual conversion pass
DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
convert_node_fn(old_node);
node_map[old_node.get()]->inputs.clear(); // make this pass "virtual" by removing edges
});
// true conversion pass
DFSVisit(src.outputs, [&](const ObjectPtr& old_node) { convert_node_fn(old_node); });
// outputs of the new graph are provided in the original dtypes of the source graph outputs
std::vector<NodeEntry> outputs;
for (const auto& out_ne : src.outputs) {
outputs.push_back(entry_map.at(out_ne).AsOriginal());
}
if (cast_params_offline) {
RemoveParamCasts(
target_dtype, offline_param_cast_attr, node_map, old_param_dst_nodes, &entry_map);
}
// only the outputs of the returned graph are set
Graph ret;
ret.outputs = std::move(outputs);
return ret;
}
// Registers the ReducePrecision function as the body of the nnvm pass "ReducePrecision";
// set_change_graph(true) presumably declares that the pass modifies the graph structure -
// confirm against the nnvm pass registry documentation.
NNVM_REGISTER_PASS(ReducePrecision)
.describe("add cast layers for low precision inference")
.set_body(ReducePrecision)
.set_change_graph(true);
} // namespace mxnet