blob: 6dde61359df654a00913bb39a17804b6c694fc08 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/qnn/op/mul.cc
* \brief QNN mul operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../transforms/pattern_utils.h"
#include "../utils.h"
#include "op_common.h"
namespace tvm {
namespace relay {
namespace qnn {
/*
* \brief Canonicalizes the QNN mul op.
* \param attrs The QNN concatenate attrs.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for mul op.
*/
Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
Expr output;
// Get the attrs.
QnnBinaryOpArguments args(new_args);
// Get the input dtype and shape.
QnnBinaryOpTensorType input_type(arg_types, 0);
// data types
const auto int32_dtype = DataType::Int(32);
const auto float32_dtype = DataType::Float(32);
const auto* broadcast_attrs = attrs.as<BroadcastAttrs>();
ICHECK(broadcast_attrs != nullptr);
auto lhs_axis = broadcast_attrs->lhs_axis;
auto rhs_axis = broadcast_attrs->rhs_axis;
if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) {
/*
This is per-tensor quantized multiply.
A tensor multiplication c = a * b can be written in terms of respective
quantized tensors, scales and zero points as
S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
0. The quantized multiplication then becomes
Q_c = S'/S_c Q' + z_c,
which is essentially a requantization of tensor Q' into tensor Q_c.
*/
auto lhs_shifted = Cast(args.lhs, int32_dtype);
auto rhs_shifted = Cast(args.rhs, int32_dtype);
auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
}
if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
}
// Create a new tensor Q'
output = Multiply(lhs_shifted, rhs_shifted);
// Get the adjusted new scale and zero points.
float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
float new_scale_float = lhs_scale_float * rhs_scale_float;
auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
auto new_input_zero_point = zero_scalar;
// Requantize to get Q_c
output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
args.output_scale, args.output_zero_point, input_type.dtype);
} else if (lhs_axis == rhs_axis) {
/*
This is per-channel quantized multiply, assumming lhs_axis and rhs_axis are the same.
The subtract is done on the specified axis via broadcast. Then, we multiply lhs and rhs.
The output is requantized using new scale and axis. TODO: support different axes.
*/
auto lhs_data = Cast(args.lhs, int32_dtype);
auto rhs_data = Cast(args.rhs, int32_dtype);
auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
// Broadcast lhs zero point if needed
int rank = static_cast<int>(input_type.shape.size());
int axis = (lhs_axis < 0) ? ((rank > 0) ? rank + lhs_axis : 0) : lhs_axis;
Expr lhs_zero_broadcast = ExpandBiasToMatchAxis(Reshape(args.lhs_zero_point,
{
-1,
}),
rank, {axis});
lhs_data = Subtract(lhs_data, Cast(lhs_zero_broadcast, DataType::Int(32)));
}
if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
// Broadcast rhs zero point if needed
int rank = static_cast<int>(input_type.shape.size());
int axis = (rhs_axis < 0) ? ((rank > 0) ? rank + rhs_axis : 0) : rhs_axis;
Expr rhs_zero_broadcast = ExpandBiasToMatchAxis(Reshape(args.rhs_zero_point,
{
-1,
}),
rank, {axis});
rhs_data = Subtract(rhs_data, Cast(rhs_zero_broadcast, DataType::Int(32)));
}
// Create a new tensor Q'
output = Multiply(lhs_data, rhs_data);
// Requantize to get Q_c
auto lhs_scales = GetFloatVectorFromConstant(args.lhs_scale);
auto rhs_scales = GetFloatVectorFromConstant(args.rhs_scale);
std::vector<double> output_multipliers;
for (size_t i = 0; i < lhs_scales.size(); i++) {
double multiplier = static_cast<double>(lhs_scales[i]) * static_cast<double>(rhs_scales[i]);
output_multipliers.push_back(multiplier);
}
auto new_input_scale = MakeConstantTensor(
DataType::Float(32), {(int64_t)output_multipliers.size()}, output_multipliers);
output = Requantize(output, input_type.shape, new_input_scale, zero_scalar, args.output_scale,
args.output_zero_point, input_type.dtype, lhs_axis);
} else {
LOG(FATAL) << "Not supported: lhs_axis and rhs_axis are not the same.";
}
return output;
}
// QNN Multiplication operator.
QNN_REGISTER_BINARY_OP("mul")
.describe("Elementwise mul with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
} // namespace qnn
} // namespace relay
} // namespace tvm