/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "op_common.h"

#include <algorithm>
#include <vector>

namespace tvm {
namespace relax {
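
/*! \brief Get the arguments of a call.
 *
 * For most operators the arguments are `call->args` directly.  The
 * `relax.call_tir` operator instead packs the operator arguments into an
 * in-line tuple as its second argument, so that tuple's fields are
 * returned instead.
 */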
Array<Expr> GetCallArgs(const Call& call) {
  static const Op& call_tir_op = Op::Get("relax.call_tir");
  Array<Expr> args;
  if (call->op.same_as(call_tir_op)) {
    args = Downcast<Tuple>(call->args[1])->fields;
  } else {
    args = call->args;
  }
  return args;
}
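
/*! \brief Check that a call provides as many arguments as the operator declares.
 *
 * Reports a fatal diagnostic through the BlockBuilder if the argument
 * count does not match `op->arguments`.
 */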
void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
  Op op = Downcast<Op>(call->op);
  int expected_input = op->arguments.size();
  if (static_cast<int>(call->args.size()) != expected_input) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Operator " << op << " expects " << expected_input << " arguments"
                     << ", but was called with " << call->args.size() << " arguments");
  }
}
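
/*! \brief Get the struct info of a call's i-th argument, which must be a tensor.
 *
 * Reports a fatal diagnostic through the BlockBuilder if the argument
 * is not a tensor.
 */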
TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) {
  Op op = Downcast<Op>(call->op);
  ICHECK_EQ(op->arguments.size(), call->args.size())
      << "Failure caught by this check "
      << "should have previously been caught by `CheckNumArguments`";
  ICHECK_LT(i_arg, op->arguments.size());
  auto arg = call->args[i_arg];
  auto sinfo = GetStructInfo(arg);
  if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
    return tensor_sinfo.value();
  } else {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Operator " << op << " requires argument " << i_arg << " ("
                     << op->arguments[i_arg]->name << ") to be a tensor. "
                     << "However, the argument " << arg << " instead has struct info " << sinfo);
    // Unreachable, but the [[noreturn]] attribute on the virtual function
    // `ReportFatal` is insufficient to silence -Wreturn-type, as a child
    // class might not be [[noreturn]].
    return TensorStructInfo();
  }
}
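
/*! \brief Get the tensor struct info of every argument of a call.
 *
 * Checks the argument count first, then requires each argument to be
 * a tensor.
 */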
Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
  CheckNumArguments(call, ctx);
  Array<TensorStructInfo> input_tensor_sinfo;
  input_tensor_sinfo.reserve(call->args.size());
  for (size_t i = 0; i < call->args.size(); ++i) {
    input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
  }
  return input_tensor_sinfo;
}
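
/*! \brief Get the tensor struct info of each field of a tuple expression.
 *
 * Reports a fatal diagnostic through the BlockBuilder if `tup` is not a
 * tuple, or if any of its fields is not a tensor.
 */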
Array<TensorStructInfo> GetTensorStructInfoFromTuple(const Call& call, const BlockBuilder& ctx,
                                                     const Expr& tup) {
  const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(tup);
  if (tuple_sinfo == nullptr) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << call->op
                     << " expects the input to be a Tuple of Tensors. However, the given input is "
                     << tup->struct_info_->GetTypeKey());
  }
  Array<TensorStructInfo> tensor_sinfo;
  tensor_sinfo.reserve(tuple_sinfo->fields.size());
  for (StructInfo field_sinfo : tuple_sinfo->fields) {
    const auto* field_tensor_sinfo = field_sinfo.as<TensorStructInfoNode>();
    if (field_tensor_sinfo == nullptr) {
      ctx->ReportFatal(
          Diagnostic::Error(call)
          << call->op << " expects the input to be a Tuple of Tensors. However, the given input is "
          << tup->struct_info_);
    }
    tensor_sinfo.push_back(GetRef<TensorStructInfo>(field_tensor_sinfo));
  }
  return tensor_sinfo;
}
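
/*! \brief Infer the output shape of a binary elementwise op with NumPy-style broadcasting.
 *
 * Trailing dimensions are matched pairwise: a dimension of 1 broadcasts
 * against the other dimension, provably equal dimensions are kept, and
 * provably unequal dimensions are a fatal error.  For example, shapes
 * (2, 1, 3) and (4, 3) broadcast to (2, 4, 3).  Returns NullOpt when a
 * pair of symbolic dimensions can be proven neither equal nor unequal.
 */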
Optional<Array<PrimExpr>> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx,
                                                    const Array<PrimExpr>& x1_shape,
                                                    const Array<PrimExpr>& x2_shape) {
  arith::Analyzer* analyzer = ctx->GetAnalyzer();
  int x1_ndim = x1_shape.size();
  int x2_ndim = x2_shape.size();
  int max_ndim = std::max(x1_ndim, x2_ndim);
  std::vector<PrimExpr> output_shape;
  output_shape.reserve(max_ndim);
  int i = 1;
  for (; i <= std::min(x1_ndim, x2_ndim); ++i) {
    const PrimExpr& dim0 = x1_shape[x1_ndim - i];
    const PrimExpr& dim1 = x2_shape[x2_ndim - i];
    const auto* int_dim0 = dim0.as<IntImmNode>();
    const auto* int_dim1 = dim1.as<IntImmNode>();
    if (int_dim0 != nullptr && int_dim0->value == 1) {
      output_shape.push_back(dim1);
    } else if (int_dim1 != nullptr && int_dim1->value == 1) {
      output_shape.push_back(dim0);
    } else if (analyzer->CanProveEqual(dim0, dim1)) {
      output_shape.push_back(dim0);
    } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "In " << call->op << ", the first input shape at dim " << x1_ndim - i
                       << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i
                       << " is " << dim1 << ", which are not broadcastable.");
    } else {
      // The symbolic dimensions can be proven neither equal nor unequal,
      // so fall back to an unknown output shape.
      return NullOpt;
    }
  }
  // The remaining leading dimensions come from the longer shape.
  const auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape;
  for (; i <= max_ndim; ++i) {
    output_shape.push_back(longer_shape[max_ndim - i]);
  }
  return Array<PrimExpr>(output_shape.rbegin(), output_shape.rend());
}
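
/*! \brief Normalize a list of axes to be non-negative and in-range.
 *
 * Negative axes are interpreted relative to `ndim`; for example, with
 * ndim = 4 the axes [-1, 1] normalize to [3, 1].  Out-of-range or
 * repeated axes are a fatal error.
 */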
std::vector<int> NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim,
                               const Array<Integer>& axes) {
  ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function.";
  std::vector<bool> appeared_dims_set;
  std::vector<int> axes_non_neg;
  appeared_dims_set.resize(ndim, /*value=*/false);
  axes_non_neg.reserve(axes.size());
  for (const Integer& axis : axes) {
    int _axis = axis->value;
    if (_axis < -ndim || _axis >= ndim) {
      ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the input axis " << _axis
                                               << " is out of range. The input tensor has " << ndim
                                               << " dimensions, so axis should be in range ["
                                               << -ndim << ", " << ndim << ").");
    } else if (_axis < 0) {
      _axis = ndim + _axis;
    }
    if (appeared_dims_set[_axis]) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "In " << call->op
                       << ", the input axes are required to be distinct. However, multiple "
                          "given axes refer to axis "
                       << _axis);
    }
    appeared_dims_set[_axis] = true;
    axes_non_neg.push_back(_axis);
  }
  return axes_non_neg;
}
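
/*! \brief Infer the layout for a unary elementwise op.
 *
 * Unary elementwise ops are layout-agnostic, so the output simply
 * follows whatever layout was decided for the input.
 */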
InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
                                        const Map<String, Array<String>>& desired_layouts,
                                        const VarLayoutMap& var_layout_map) {
  ICHECK(NoDesiredLayout(call, desired_layouts));
  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
  return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs));
}

}  // namespace relax
}  // namespace tvm