blob: 5cc180e8e2e39a8403580d187d78b8c61ea429c1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file infer_layout_util.h
 * \brief Utility functions for inferring and correcting the layouts of operators. They support
 *  passes that alter the layouts of operators or replace primitive operators with other
 *  expressions, e.g. for computing convolution in custom layouts or other general weight
 *  pre-transformation.
 */
#ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
#define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>
#include <string>
#include <tuple>
#include "pattern_util.h"
namespace tvm {
namespace relay {
/*!
 * \brief Returns a new layout where the subordinate factors are adjusted based on the tensor
 *        shape.
 *
 * Every axis of \p src_layout is copied to the result in order. Primal axes are copied as-is.
 * For each subordinate axis the factor is chosen as follows:
 *  - if the corresponding primal axis does not appear in \p old_layout, the factor is 1
 *    (there is no extent to look up for that axis);
 *  - if the extent of that primal axis in \p old_shape is the constant 1, the factor is 1;
 *  - otherwise the factor of \p src_layout is retained.
 *
 * \param src_layout The layout whose axes are walked to build the result.
 * \param old_layout The old layout before any transformation.
 * \param old_shape The shape of the original tensor.
 * \return The adjusted Layout.
 */
inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
                                       const Array<tvm::PrimExpr>& old_shape) {
  std::string new_layout;
  for (auto axis : src_layout->axes) {
    const auto& layout_axis = LayoutAxis::Get(axis);
    if (!layout_axis.IsPrimal()) {
      // 1) Find the corresponding dual (primal) axis.
      const auto& dual_axis = layout_axis.ToPrimal();
      // 2) Find the index of this dual axis in old_layout.
      const int old_axis = old_layout.IndexOf(dual_axis);

      bool is_shape_one = false;
      if (old_axis < 0) {
        // IndexOf reports a missing axis with a negative index. Indexing old_shape with it
        // would be out of bounds, so treat the missing axis as having extent 1.
        is_shape_one = true;
      } else {
        // 3) Find the shape of this index in old_shape.
        auto shape_val = old_shape[old_axis];
        // 4) a) Check if this shape element is the constant 1. Non-constant extents are
        //       never considered to be 1.
        if (auto* shape_int = shape_val.as<IntImmNode>()) {
          is_shape_one = (shape_int->value == 1);
        }
      }

      if (is_shape_one) {
        new_layout += "1";
      } else {
        // 4) b) If shape is not 1, retain the factor of the source layout.
        auto new_shape_val = src_layout.FactorOf(dual_axis);
        new_layout += std::to_string(new_shape_val);
      }
    }
    new_layout += layout_axis.name();
  }
  return Layout(new_layout);
}
/*!
 * \brief Signature of the function that infers and corrects the layouts of a node.
 *        See \p Layout for the layout convention.
 * \param attrs The attribute of the node.
 * \param new_in_layouts The layouts of input arguments after alter_op_layout.
 *                       This can be undefined, which means we call this function before altering
 *                       any operators.
 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
 * \param old_in_types The types of old input arguments.
 * \return inferred_layouts An array of two elements that are inferred input layouts and
 *                          inferred output layouts.
 */
using FInferCorrectLayout = runtime::TypedPackedFunc<Array<Array<Layout>>(
const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types)>;
/*!
 * \brief Take an arbitrary input layout and copy it to every input and to the output.
 *
 * When \p new_in_layouts is defined its first entry is used; otherwise the first defined
 * entry of \p old_in_layouts is used (the layout stays undefined if there is none).
 *
 * \return {input layouts, output layouts}, where the input array has one copy of the chosen
 *         layout per entry of \p old_in_layouts and the output array has a single copy.
 */
inline Array<Array<Layout>> ElemwiseArbitraryLayout(const Attrs& attrs,
                                                    const Array<Layout>& new_in_layouts,
                                                    const Array<Layout>& old_in_layouts,
                                                    const Array<tvm::relay::Type>& old_in_types) {
  Layout chosen;
  if (!new_in_layouts.defined()) {
    // No altered layouts yet: scan the original layouts for the first defined one.
    const size_t num_old = old_in_layouts.size();
    size_t idx = 0;
    while (idx < num_old) {
      if (old_in_layouts[idx].defined()) {
        chosen = old_in_layouts[idx];
        break;
      }
      ++idx;
    }
  } else {
    CHECK_GE(new_in_layouts.size(), 1);
    chosen = new_in_layouts[0];
  }

  const Array<Layout> input_layouts(old_in_layouts.size(), chosen);
  const Array<Layout> output_layouts{chosen};
  return Array<Array<Layout>>{input_layouts, output_layouts};
}
/*!
 * \brief Infer layout for binary broadcast operators.
 *
 * The altered layouts (\p new_in_layouts) are used when defined, otherwise the original ones.
 * The cases are handled in this order:
 *  - neither operand has a defined layout: inference fails (all layouts undefined);
 *  - exactly one operand has a defined layout: the other one receives a sub-layout taken from
 *    the tail of the defined layout, provided the defined operand has at least as many
 *    dimensions; otherwise inference fails;
 *  - one of the layouts has zero dimensions: the layouts are kept and the output uses the
 *    other operand's layout;
 *  - otherwise the operand with more primal axes decides the output layout and the other
 *    operand is adapted to it.
 *
 * \return {input layouts, output layouts}.
 */
inline Array<Array<Layout>> BinaryBroadcastLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
  // Result used whenever the layouts cannot be inferred.
  const auto infer_failed = []() {
    return Array<Array<Layout>>{{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}};
  };

  // Gather the shapes of the original inputs; each of them has to be a tensor type.
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    CHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }

  // Work on a copy of the altered layouts if they exist, of the original layouts otherwise.
  Array<Layout> layouts;
  const Array<Layout>& initial_layouts = new_in_layouts.defined() ? new_in_layouts : old_in_layouts;
  layouts.Assign(initial_layouts.begin(), initial_layouts.end());

  const bool lhs_defined = layouts[0].defined();
  const bool rhs_defined = layouts[1].defined();

  if (!lhs_defined && !rhs_defined) {
    // both undefined, infer fails
    return infer_failed();
  }

  if (!(lhs_defined && rhs_defined)) {
    // only one is defined, use shape information to help infer
    const int defined_idx = lhs_defined ? 0 : 1;
    const int undef_idx = 1 - defined_idx;
    const size_t defined_rank = old_in_shapes[defined_idx].size();
    const size_t undef_rank = old_in_shapes[undef_idx].size();
    if (defined_rank < undef_rank) {
      // only know the tensor with smaller dimensions,
      // so we cannot infer the final broadcasted output.
      // fails in this case.
      return infer_failed();
    }
    layouts.Set(undef_idx, layouts[defined_idx].SubLayout(defined_rank - undef_rank, undef_rank));
    return Array<Array<Layout>>{layouts, {layouts[defined_idx]}};
  }

  // From here on both layouts are defined.
  if (layouts[0].ndim() == 0 || layouts[1].ndim() == 0) {
    const int scalar = layouts[0].ndim() == 0 ? 0 : 1;
    return Array<Array<Layout>>{layouts, {layouts[1 - scalar]}};
  }

  // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims
  // while transforming layout.
  const int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
  const int small_idx = 1 - large_idx;
  Layout ret = layouts[large_idx];

  if (!old_in_layouts[0].Equals(old_in_layouts[1])) {
    // Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
    // while transforming the layout, we expand dims to make C go to NHWC, and then use the
    // modified layout of the first operator to call the layout transform. E.g.
    //    a in NCHW16c
    //    b in C
    //    b = expand_dims(b) from C -> NHWC
    //    b = layout_transform(b) from NHWC -> NCHW16c
    //    add(a, b)
    layouts.Set(small_idx, ret);
  } else {
    // Support scenarios where original operands were of type [N, H, W, C] and [N, H, W, 1]
    // In this case, we might have NCHW16c coming for 1 operand. However, the other operand does
    // not have enough C dimension. To reuse broadcasting, we would want to use NCHW1c for the
    // second operand. The following section of code walks through the layouts and shapes to
    // perform that operation.
    //    a in NCHW16c
    //    b in NHW1
    //    b = layout_transform(b) from NHW1 -> NCHW1c
    //    add(a, b)
    auto small_shape = old_in_shapes[small_idx];
    auto small_layout = old_in_layouts[small_idx];
    auto adjusted_layout = AdjustSubordinateFactors(ret, small_layout, small_shape);
    layouts.Set(small_idx, adjusted_layout);
  }
  return Array<Array<Layout>>{layouts, {ret}};
}
/*!
 * \brief Call the registered FInferCorrectLayout of an op.
 *
 * Parameters are the same as the parameters for FInferCorrectLayout, except that the call
 * node is passed instead of its attributes.
 *
 * \return A tuple (inferred_input_layout, inferred_output_layout, success). Inference fails,
 *         and both arrays are null, when the callee is not an operator, when the operator has
 *         no registered FInferCorrectLayout, or when any inferred layout is undefined.
 */
static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
    const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");

  using Result = std::tuple<Array<Layout>, Array<Layout>, bool>;
  // Result reported for every kind of inference failure.
  const auto failure = []() { return Result(Array<Layout>(nullptr), Array<Layout>(nullptr), false); };

  // Only calls to operators can have a layout inference function.
  if (call->op.as<OpNode>() == nullptr) {
    return failure();
  }

  Op op = Downcast<Op>(call->op);
  if (!finfer_layout.count(op)) {
    return failure();
  }

  Array<Array<Layout>> inferred_layouts =
      finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types);
  CHECK_EQ(inferred_layouts.size(), 2)
      << "FInferCorrectLayout should return an array with size of 2";

  // A single undefined layout, input or output, means the inference failed.
  for (auto layout_group : inferred_layouts) {
    for (auto layout : layout_group) {
      if (!layout.defined()) {
        return failure();
      }
    }
  }
  return Result(inferred_layouts[0], inferred_layouts[1], true);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_