blob: 3a38ee5c5e2d3c8efb15a0321198ea3597d07c9e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file alm.cc
* \brief Automatic Layout Manager
* \author Dawid Tracz, Vladimir Cherepanov
*/
#include "alm.h"
#include <algorithm>
#include <sstream>
#include <unordered_set>
#include <utility>
#include "../operator/nn/convolution-inl.h"
#include "../operator/nn/deconvolution-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
namespace mxnet {
namespace alm {
namespace {
nnvm::ObjectPtr CreateTransposeNode(const std::string& name, const alm::Transpose& axes) {
nnvm::ObjectPtr newptr = nnvm::Node::Create();
newptr->attrs.op = nnvm::Op::Get("transpose");
newptr->attrs.name = name;
// set tranpose axes
std::ostringstream ss;
ss << mxnet::TShape(axes.begin(), axes.end());
newptr->attrs.dict["axes"] = ss.str();
newptr->op()->attr_parser(&(newptr->attrs));
return newptr;
}
mshadow::LayoutFlag TargetLayout(const nnvm::ObjectPtr& node) {
static const Op* conv_op = Op::Get("Convolution");
static const Op* deconv_op = Op::Get("Deconvolution");
static const std::unordered_map<int, mshadow::LayoutFlag> ndim2layout{
{1, mshadow::kNWC},
{2, mshadow::kNHWC},
{3, mshadow::kNDHWC},
};
auto target_layout = [](const auto& param) {
auto it = ndim2layout.find(param.kernel.ndim());
CHECK(it != ndim2layout.end()) << "Unexpected kernel dimensions: " << param.kernel;
return it->second;
};
if (node->op() == conv_op)
return target_layout(nnvm::get<op::ConvolutionParam>(node->attrs.parsed));
if (node->op() == deconv_op)
return target_layout(nnvm::get<op::DeconvolutionParam>(node->attrs.parsed));
return mshadow::kUNKNOWN;
}
} // namespace
nnvm::Graph OptimizeLayout(nnvm::Graph&& g) {
static const auto& op_map = Op::GetAttr<mxnet::alm::FChangeLayout>("FChangeLayout");
static const Op* transpose_op = Op::Get("transpose");
std::unordered_set<nnvm::ObjectPtr> outputs;
for (auto& o : g.outputs)
outputs.insert(o.node);
nnvm::NodeEntryMap<alm::Transpose> changed;
struct ToDelete {
nnvm::ObjectPtr node; // output of the transpose
size_t input_idx;
};
std::vector<ToDelete> to_delete;
struct ToAdd {
nnvm::ObjectPtr node;
size_t input_idx;
alm::Transpose axes;
};
std::vector<ToAdd> to_add;
DFSVisit(g.outputs, [&outputs, &changed, &to_add, &to_delete](const nnvm::ObjectPtr& node) {
std::vector<alm::Transpose> input_axes(node->inputs.size());
for (size_t i = 0; i < node->inputs.size(); ++i) {
if (node->inputs[i].node->op() == transpose_op) {
const auto& param = nnvm::get<op::TransposeParam>(node->inputs[i].node->attrs.parsed);
if (IsIdentity(FromTShape(param.axes))) {
to_delete.push_back({node, i});
continue;
}
}
auto it = changed.find(node->inputs[i]);
if (it == changed.end())
continue;
input_axes[i] = it->second;
}
auto fchange = op_map.get(node->op(), nullptr);
if (fchange && outputs.count(node) == 0) {
std::vector<alm::Transpose> output_axes;
if (fchange(&node->attrs, TargetLayout(node), &input_axes, &output_axes))
node->op()->attr_parser(&node->attrs);
for (size_t i = 0; i < output_axes.size(); ++i) {
if (IsIdentity(output_axes[i]))
continue;
changed.insert(std::make_pair(nnvm::NodeEntry(node, i, 0), output_axes[i]));
}
}
for (size_t i = 0; i < input_axes.size(); ++i) {
if (IsIdentity(input_axes[i]))
continue;
to_add.push_back({node, i, input_axes[i]});
}
});
for (const auto& t : to_delete) {
auto& tnode = t.node->inputs[t.input_idx].node;
CHECK_EQ(tnode->inputs.size(), 1);
t.node->inputs[t.input_idx] = tnode->inputs[0];
}
size_t node_no = 0;
for (const auto& t : to_add) {
auto tnode = CreateTransposeNode("ALM_transpose_" + std::to_string(node_no++), t.axes);
tnode->inputs.push_back(t.node->inputs[t.input_idx]);
t.node->inputs[t.input_idx] = nnvm::NodeEntry(tnode);
}
nnvm::Graph ret;
ret.outputs = g.outputs;
return ret;
}
Transpose Reverse(const Transpose& axes) {
Transpose rev(axes.size());
for (size_t i = 0; i < rev.size(); i++)
rev[axes[i]] = i;
return rev;
}
Transpose Compose(const Transpose& lhs, const Transpose& rhs) {
if (lhs.empty())
return rhs;
if (rhs.empty())
return lhs;
CHECK_EQ(lhs.size(), rhs.size());
Transpose ret(lhs.size());
for (auto i = 0; i < ret.size(); ++i)
ret[i] = lhs[rhs[i]];
return ret;
}
bool IsIdentity(const Transpose& t) {
for (size_t i = 0; i < t.size(); ++i) {
if (t[i] != i)
return false;
}
return true;
}
mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose& axes) {
auto ret = mshadow::layoutFlag(ApplyTranspose(mshadow::toString(layout), axes));
CHECK_NE(ret, mshadow::kUNKNOWN);
return ret;
}
std::string ApplyTranspose(const std::string& layout, const Transpose& axes) {
std::string ret(layout.size(), ' ');
for (size_t i = 0; i < ret.size(); i++)
ret[i] = layout[axes[i]];
return ret;
}
Transpose FromTShape(const mxnet::TShape& s) {
Transpose ret(s.ndim());
std::copy(s.begin(), s.end(), ret.begin());
return ret;
}
Transpose FactorCommonTranspose(std::vector<Transpose>* axes) {
Transpose ret;
for (auto& t : *axes) {
if (IsIdentity(t))
continue;
if (IsIdentity(ret)) {
std::swap(t, ret);
continue;
}
auto rev = Reverse(ret);
t = Compose(t, rev);
}
return ret;
}
} // namespace alm
} // namespace mxnet