/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file pad.cc
* \brief Implementation of operator pad
*/
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>
#include <tvm/topi/elemwise.h>
#include <tvm/topi/nn.h>
#include <vector>
#include "../make_op.h"
#include "../op_common.h"
namespace tvm {
namespace relay {
// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);
Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
Layout ret_data;
// If new_in_layouts is defined, try to adapt pad_width to the new layout.
bool is_layout_modified = new_in_layouts.defined();
if (new_in_layouts.defined()) {
// Create a map from axis name to pad_width. For the new layout, a new pad_width is generated
// using the map. The new layout is rejected if padding occurs along an axis that was split.
// 1) Create a map from axis name to pad_width using the old layout.
std::map<std::string, tvm::Array<Integer>> axis_pad_width;
int index_counter = 0;
ICHECK_EQ(new_in_layouts.size(), 2);
ICHECK_EQ(old_in_layouts.size(), 2);
for (auto iter_var : old_in_layouts[0]->axes) {
const auto& old_layout_axis = LayoutAxis::Get(iter_var);
axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
index_counter++;
}
// 2) Create new pad width by walking over the new layout and using the map.
tvm::Array<tvm::Array<Integer>> new_pad_width;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& new_layout_axis = LayoutAxis::Get(iter_var);
auto axis_name = new_layout_axis.name();
if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
// This is a primal axis, so directly use the original pad_width.
new_pad_width.push_back(axis_pad_width.at(axis_name));
} else {
// This is the axis that got split. So, check that pad_width was [0, 0] originally.
const auto& dual_axis = new_layout_axis.ToPrimal();
auto dual_axis_name = dual_axis.name();
ICHECK(axis_pad_width.count(dual_axis_name))
<< "Missing axis " << dual_axis << " in " << old_in_layouts[0].name();
new_pad_width.push_back(axis_pad_width.at(dual_axis_name));
// If any pad_width element is not zero, do not change the layout.
for (auto width : axis_pad_width.at(dual_axis_name)) {
if (auto* width_imm = width.as<IntImmNode>()) {
if (width_imm->value != 0) {
is_layout_modified = false;
}
} else {
is_layout_modified = false;
}
}
}
}
// If the above conditions are satisfied, set the newly created pad_width and use the new layout.
if (is_layout_modified) {
ret_data = new_in_layouts[0];
params->pad_width = new_pad_width;
}
}
if (!is_layout_modified) {
if (old_in_layouts.defined()) {
ICHECK_EQ(old_in_layouts.size(), 2);
ret_data = old_in_layouts[0];
} else {
ret_data = Layout::Undef();
}
}
// The pad value is always a scalar
Layout ret_pad_value = Layout("1");
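// Return the input layouts (data, pad_value) followed by the output layout.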
return Array<Array<Layout>>{{ret_data, ret_pad_value}, {ret_data}};
}
bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [pad_data_type, pad_value_type, ret_type]
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const PadAttrs* param = attrs.as<PadAttrs>();
ICHECK(param != nullptr);
// Check that the number of pad width pairs matches the number of input dimensions.
ICHECK(data->shape.size() == param->pad_width.size())
<< "There should be as many pad width pairs as shape dimensions "
<< "but the shape has " << data->shape.size() << " dimensions "
<< "and there are " << param->pad_width.size() << " pad width pairs.";
// Each pad width element should be a pair of constant integers.
std::vector<IndexExpr> oshape;
for (size_t i = 0; i < param->pad_width.size(); i++) {
ICHECK(param->pad_width[i].size() == 2)
<< "Each pad width element should be a pair but at index " << i << " there are "
<< param->pad_width[i].size() << " elements.";
auto width1 = tir::as_const_int(param->pad_width[i][0]);
auto width2 = tir::as_const_int(param->pad_width[i][1]);
ICHECK(width1 != nullptr);
ICHECK(width2 != nullptr);
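// Fold the padding into the output shape only when the dimension is statically known;
// dynamic (Any) dimensions are passed through unchanged.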
if (!data->shape[i].as<tir::AnyNode>()) {
auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
if (tir::as_const_int(data->shape[i])) {
ICHECK(topi::detail::GetConstInt(data->shape[i] + padding) >= 0)
<< "Output shape post padding should be non-negative but got " << data->shape[i] + padding;
}
} else {
oshape.push_back(data->shape[i]);
}
}
reporter->Assign(types[2], TensorType(Array<IndexExpr>(oshape), data->dtype));
return true;
}
Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<PadAttrs>();
ICHECK(param != nullptr);
auto pad_width = param->pad_width;
ICHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width";
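// Split each (before, after) pad_width pair into the separate pad_before / pad_after arrays
// expected by topi::pad.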
Array<IndexExpr> pad_before;
for (size_t i = 0; i < pad_width.size(); ++i) {
pad_before.push_back(pad_width[i][0]);
}
Array<IndexExpr> pad_after;
for (size_t i = 0; i < pad_width.size(); ++i) {
pad_after.push_back(pad_width[i][1]);
}
te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype);
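// Index the casted 0-d pad value tensor with an empty index list to obtain the scalar pad expression.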
const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>());
return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad",
topi::kElementWise, param->pad_mode)};
}
// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode) {
auto attrs = make_object<PadAttrs>();
attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
static const Op& op = Op::Get("nn.pad");
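// The pad value is passed as the second call argument (not an attribute), matching the two
// inputs declared in the op registration below.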
return Call(op, {data, pad_value}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad);
RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type<PadAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("pad_val", "Tensor", "The value to fill the padded area with")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);
// relay.nn.mirror_pad
TVM_REGISTER_NODE_TYPE(MirrorPadAttrs);
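// Unlike nn.pad, mirror padding takes only the data tensor as input; the pad widths must be
// non-negative constants and the mirroring mode is carried as an attribute.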
bool MirrorPadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
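// types = [pad_data_type, ret_type]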
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const MirrorPadAttrs* param = attrs.as<MirrorPadAttrs>();
ICHECK(param != nullptr);
// Check that the number of pad width pairs matches the number of input dimensions.
ICHECK(data->shape.size() == param->pad_width.size())
<< "There should be as many pad width pairs as shape dimensions "
<< "but the shape has " << data->shape.size() << " dimensions "
<< "and there are " << param->pad_width.size() << " pad width pairs.";
// Each pad width element should be a pair of non-negative constant integers.
std::vector<IndexExpr> oshape;
for (size_t i = 0; i < param->pad_width.size(); i++) {
ICHECK(param->pad_width[i].size() == 2)
<< "Each pad width element should be a pair but at index " << i << " there are "
<< param->pad_width[i].size() << " elements.";
auto width1 = tir::as_const_int(param->pad_width[i][0]);
auto width2 = tir::as_const_int(param->pad_width[i][1]);
ICHECK(width1 != nullptr);
ICHECK(width2 != nullptr);
ICHECK(*width1 >= 0) << "Pad width elements should be non-negative but the first pad width at "
<< "index " << i << " is " << *width1 << ".";
ICHECK(*width2 >= 0) << "Pad width elements should be non-negative but the second pad width at "
<< "index " << i << " is " << *width2 << ".";
auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
}
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
return true;
}
// Handler to create a call to the mirror padding op used by front-end FFI
Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr>> pad_width, String mode) {
auto attrs = make_object<MirrorPadAttrs>();
attrs->mode = mode;
attrs->pad_width = std::move(pad_width);
static const Op& op = Op::Get("nn.mirror_pad");
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad);
RELAY_REGISTER_OP("nn.mirror_pad")
.describe(R"code(MirrorPad for n-D tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type<MirrorPadAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MirrorPad", MirrorPadRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm