blob: 923f4eb34391dc8f831abf45b50cec17fac623bf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file alm.h
* \brief Automatic Layout Manager
* \author Dawid Tracz, Vladimir Cherepanov
*/
#ifndef MXNET_COMMON_ALM_H_
#define MXNET_COMMON_ALM_H_
#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
namespace mxnet {
namespace alm {
/*!
 * \brief Singleton holding the layout-optimization flag.
 *
 * The flag is set and read by MXSetOptimizeLayout and MXGetOptimizeLayout.
 */
struct ALMParams {
  /*!
   * \brief Access the single shared instance.
   * \return Reference to a function-local static object, constructed on the first call.
   */
  static ALMParams& get() {
    static ALMParams instance;
    return instance;
  }

  /*! \brief Layout-optimization flag; starts out as false. */
  bool optimize{false};
};
/*!
* \brief Top-level function to run layout optimization.
*
* \param g The input graph, passed as an rvalue reference.
* \return A graph, returned by value.
*/
nnvm::Graph OptimizeLayout(nnvm::Graph&& g);
/*!
* \brief Transpose, represented by permutation of axes.
*
* Stored as a std::vector<size_t>, so a default-constructed Transpose is an empty vector.
*/
using Transpose = std::vector<size_t>;
/*!
* \brief Predicate on a transpose.
*
* NOTE(review): judging by the name, presumably returns true when t leaves the axis order
* unchanged. The definition is not in this header, and how an empty Transpose is treated is
* not visible here - confirm against the implementation.
*
* \param t The transpose to inspect.
*/
bool IsIdentity(const Transpose& t);
/*!
* \brief Derives a new transpose from the given one.
*
* NOTE(review): judging by the name, presumably returns the inverse permutation of axes (the
* one that undoes it). The definition is not in this header - confirm against the
* implementation.
*
* \param axes The source transpose; not modified (passed by const reference).
* \return A new Transpose, returned by value.
*/
Transpose Reverse(const Transpose& axes);
/*!
* \brief Compose 2 transposes. Not commutative: a * b means b is applied first, then a.
*
* NOTE(review): presumably Compose(lhs, rhs) corresponds to lhs * rhs in the notation above,
* i.e. rhs is applied first - confirm against the definition.
*
* \param lhs Left-hand transpose.
* \param rhs Right-hand transpose.
* \return The composed transpose, returned by value.
*/
Transpose Compose(const Transpose& lhs, const Transpose& rhs);
/*!
* \brief Overloads taking a layout and a transpose and returning a layout in the same
* representation: one works on mshadow::LayoutFlag, the other on a layout string.
*
* NOTE(review): presumably the result is the input layout with its axes permuted according
* to axes. The exact permutation convention is not visible in this header - confirm against
* the definitions.
*
* \param layout The source layout.
* \param axes The transpose to apply.
*/
mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose& axes);
std::string ApplyTranspose(const std::string& layout, const Transpose& axes);
/*!
* \brief Builds a Transpose from a mxnet::TShape.
*
* NOTE(review): presumably each element of s is taken as one axis index of the resulting
* permutation - confirm against the definition.
*
* \param s The shape object to convert; not modified (passed by const reference).
* \return A new Transpose, returned by value.
*/
Transpose FromTShape(const mxnet::TShape& s);
/*!
* \brief May change operator's layout. Used in LayoutOptimization.
*
* The first argument (unnamed, nnvm::NodeAttrs*) points to the attributes of the operator;
* as the return value description implies, the callee may change them.
*
* \param target_layout The target layout to change to, or kUNKNOWN. In the latter case the target
* layout is calculated based on in_axes, with a goal to cancel them out (at least some, ideally -
* all).
* \param in_axes (in/out) On input - pending inputs' transposes. On output - inputs' transposes,
* required by the new layout.
* \param out_axes (out) Outputs' transposes, required to convert to the original layouts.
* \return true if attrs changed and params need to be reparsed.
*/
using FChangeLayout = std::function<bool(nnvm::NodeAttrs*,
mshadow::LayoutFlag target_layout,
std::vector<Transpose>* in_axes,
std::vector<Transpose>* out_axes)>;
/*!
* \brief Factors out and returns a common transpose, or default-constructed Transpose if all
* axes (in/out parameter) are empty.
*
* \param axes (in/out) The transposes to factor; passed by pointer so the callee can update
* them.
* \return The common transpose, or a default-constructed (empty) Transpose if all axes are
* empty.
*/
Transpose FactorCommonTranspose(std::vector<Transpose>* axes);
} // namespace alm
} // namespace mxnet
#endif // MXNET_COMMON_ALM_H_