blob: 3c2f6eb96d6be30a63674d7b0a8bf04309010433 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file realize.cc
*
* \brief Realizing the simulated graph into real low-precision
* graph.
*/
#include "./realize.h"
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/transform.h>
#include "../op/annotation/annotation.h"
#include "../qnn/utils.h"
#include "../transforms/fold_constant.h"
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
Expr QRealizeIntExprNode::Realize() const {
  // Dequantize: cast the integer tensor back to float32 and rescale it by the
  // domain scale to recover the simulated real value.
  Expr dequantized = Cast(this->data, DataType::Float(32));
  return Multiply(dequantized, this->dom_scale);
}
/*!
 * \brief Construct a QRealizeIntExpr reference.
 * \param data The realized integer-domain tensor expression.
 * \param dom_scale The domain scale (a float32 scalar constant expression).
 * \param dtype The storage data type of `data`.
 */
QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) {
  ObjectPtr<QRealizeIntExprNode> n = make_object<QRealizeIntExprNode>();
  n->data = std::move(data);
  n->dom_scale = std::move(dom_scale);
  // DataType is a small trivially-copyable handle; std::move on it is just a
  // copy (clang-tidy: performance-move-const-arg), so assign it directly.
  n->dtype = dtype;
  data_ = std::move(n);
}
/*! \brief Rebuild ref_call with the same op/attrs/type_args but new arguments. */
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
  const Expr& op = ref_call->op;
  const Attrs& attrs = ref_call->attrs;
  return Call(op, args, attrs, ref_call->type_args);
}
/* \brief Calculate `data * s1 / s2`, using a shift when the ratio is a power
 * of two, an integer multiply when the ratio is integral, and a fixed-point
 * multiply otherwise.
 *
 * `data` is assumed to already be stored in the activation dtype `dtype`.
 * NOTE(review): the ICHECK_GT below requires log2(s1/s2) > 0, i.e. s1 > s2;
 * callers must arrange the scales accordingly — confirm before reusing.
 */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
                      const Array<IndexExpr>& data_shape) {
  const QConfig& cfg = QConfig::Current();
  // here we assume the dtype of data is dtype activation
  if (s1 == s2) return data;
  float factor = s1 / s2;
  float shift_factor = std::log2(factor);
  ICHECK_GT(shift_factor, 0);
  if (static_cast<int>(shift_factor) == shift_factor) {
    // The ratio is an exact power of two: a left shift is exact and cheap.
    return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor)));
  } else if (static_cast<int>(factor) == factor) {
    // The ratio is an exact integer: plain integer multiply.
    return Multiply(data, MakeConstantScalar(dtype, factor));
  } else {
    // General case: emulate the real-valued multiply with fixed-point
    // arithmetic, honoring the rounding mode configured in QConfig.
    if (cfg->rounding == "UPWARD") {
      auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);
      data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
    } else {
      data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
    }
    return Cast(data, dtype);
  }
}
/*!
 * \brief Realize `simulated_quantize` into concrete integer arithmetic.
 *
 * If the input is already an integer QRealizeIntExpr, requantize it from its
 * input domain scale to the requested output domain scale (shift when the
 * scale ratio is a power of two, fixed-point multiply otherwise). Otherwise
 * quantize the float input from scratch via scale + round + clip.
 */
Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // do not handle data type cast
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  ICHECK_EQ(param->rounding, "round");
  Expr dom_scale = new_args[1];
  Expr clip_min = new_args[2];
  Expr clip_max = new_args[3];
  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
  float clip_max_imm = GetScalarFromConstant<float>(clip_max);
  // x * idom_scale = y * odom_scale
  // => y = x * idom_scale / odom_scale
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    // int32->int8
    Expr data = n->data;
    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
    float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
    if (idom_scale_imm == odom_scale_imm) {
      // same domain scale, only clip
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExpr(data, dom_scale, n->dtype);
    }
    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
    ICHECK_NE(shift_nbit, 0);
    if (static_cast<int>(shift_nbit) == shift_nbit) {
      // Scale ratio is an exact power of two: realize with a shift.
      if (shift_nbit > 0) {
        // use right shift
        if (cfg->round_for_shift) {
          // Add 2^(shift-1) first so the right shift rounds to nearest
          // instead of truncating.
          float round_bias = std::pow(2.0, shift_nbit - 1);
          data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
        }
        data = RightShift(data,
                          MakeConstantScalar(cfg->dtype_activation, static_cast<int>(shift_nbit)));
      } else {
        data = LeftShift(data,
                         MakeConstantScalar(cfg->dtype_activation, static_cast<int>(-shift_nbit)));
      }
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExpr(data, dom_scale, n->dtype);
    } else {
      // Non-power-of-two ratio: widen to int64 and apply a fixed-point
      // multiply with the configured rounding mode before clipping back.
      data = Cast(data, DataType::Int(64));
      if (cfg->rounding == "UPWARD") {
        auto [fixed_point_multiplier, shift] =
            qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
        data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
      } else {
        data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
                                                ref_call->type_as<TensorTypeNode>()->shape);
      }
      data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
      return QRealizeIntExpr(data, dom_scale, n->dtype);
    }
  }
  // quantize from real
  ICHECK(!new_args[0]->IsInstance<TempExprNode>());
  Expr data = new_args[0];
  Expr scaled_data = Multiply(data, MakeConstantScalar(DataType::Float(32), 1 / dom_scale_imm));
  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
  return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
}
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
/*!
 * \brief Realize nn.conv2d on quantized operands.
 *
 * Casts lhs to the input dtype and rhs to the weight dtype, runs the conv
 * with an activation-typed accumulator, and propagates the product of the
 * two operand domain scales as the output scale.
 */
Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 2);
  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
    const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
    const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
    Expr ldata = lhs->data;
    if (lhs->dtype != cfg->dtype_input) {
      ldata = Cast(ldata, cfg->dtype_input);
    }
    Expr rdata = Cast(rhs->data, cfg->dtype_weight);
    // Copy the reference attrs but widen the accumulator to dtype_activation.
    const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
    auto attrs = make_object<Conv2DAttrs>();
    *attrs = *ref_attrs;
    DataType out_dtype = cfg->dtype_activation;
    attrs->out_dtype = out_dtype;
    Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
    // Output scale is the product of the operand scales; fold it to a constant.
    Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
    Expr dom_scale = FoldConstantExpr(mul);
    return QRealizeIntExpr(ret, dom_scale, out_dtype);
  }
  // No rewrite: assert we are not silently dropping a partially-quantized call.
  ICHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
  return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
/*!
 * \brief Realize nn.conv1d on quantized operands.
 *
 * Mirrors Conv2dRealize: cast lhs to the input dtype, rhs to the weight
 * dtype, accumulate in dtype_activation, and fold the product of the operand
 * domain scales into the output scale.
 */
Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // Use ICHECK* for consistency with the rest of this file (CHECK is the
  // legacy alias).
  ICHECK_EQ(new_args.size(), 2);
  // NOTE(review): unlike DenseRealize/BatchMatmulRealize, this bails out only
  // when *neither* argument is a temp expr; a mixed temp/non-temp pair falls
  // through to the ICHECKs below and aborts. Preserved as-is — confirm intent
  // before unifying with the other guards.
  if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  ICHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  ICHECK(rhs);
  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
  // Copy the reference attrs but widen the accumulator to dtype_activation.
  const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
  auto attrs = make_object<Conv1DAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;
  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  // Output scale is the product of the operand scales, folded to a constant.
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantExpr(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);
/*!
 * \brief Realize nn.dense on quantized operands.
 *
 * Casts lhs to the input dtype and rhs to the weight dtype, runs the matmul
 * with an activation-typed accumulator, and folds the product of the operand
 * domain scales into the output scale.
 */
Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 2);
  // Rewrite only when both arguments are temp exprs produced by earlier
  // realize steps.
  if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  // Guard against a TempExpr that is not a QRealizeIntExpr (the unchecked
  // downcast would otherwise dereference null); matches Conv1dRealize.
  ICHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  ICHECK(rhs);
  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
  // Copy the reference attrs but widen the accumulator to dtype_activation.
  const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
  auto attrs = make_object<DenseAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;
  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  // Output scale is the product of the operand scales, folded to a constant.
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantExpr(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.dense").set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
/*!
 * \brief Realize elementwise multiply on quantized operands.
 *
 * Both operands are cast to dtype_activation before the multiply; the output
 * domain scale is the product of the operand scales.
 */
Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 2);
  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
    // execute the operation with activation data type.
    const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
    const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
    Expr ldata = lhs->data;
    Expr rdata = rhs->data;
    DataType dtype = cfg->dtype_activation;
    if (lhs->dtype != dtype) {
      ldata = Cast(ldata, dtype);
    }
    if (rhs->dtype != dtype) {
      rdata = Cast(rdata, dtype);
    }
    Expr ret = ForwardOp(ref_call, {ldata, rdata});
    // Output scale is the product of the operand scales, folded to a constant.
    Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
    Expr dom_scale = FoldConstantExpr(mul);
    return QRealizeIntExpr(ret, dom_scale, dtype);
  }
  // No rewrite: assert we are not silently dropping a partially-quantized call.
  ICHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
  return Expr(nullptr);
}
RELAY_REGISTER_OP("multiply").set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
/*!
 * \brief Pick a common domain scale for a set of quantized operands.
 *
 * For exactly two operands the smaller scale is chosen; otherwise a scale is
 * derived from the configured global scale and activation bit width.
 */
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
  if (nptrs.size() == 2) {
    // With x = a * s1 and y = b * s2 the sum can be expressed on either scale:
    //   x + y = (a * s1 / s2 + b) * s2,  if s1 > s2
    //         = (a + b * s2 / s1) * s1,  if s2 > s1
    // so keep the smaller of the two.
    const float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
    const float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
    return (s1 < s2) ? s1 : s2;
  }
  const QConfig& cfg = QConfig::Current();
  return cfg->global_scale / std::pow(2.0, cfg->nbit_activation - 1);
}
/* \brief Unify the data type and dom scale of the given arguments.
 *
 * Every argument must be a QRealizeIntExpr. The function (1) picks a common
 * data type (the explicit `dtype` if given, otherwise inferred from the
 * config), (2) casts each argument to it, stopping fusion around arguments
 * that come straight from a simulated_quantize of kind kQInput, and
 * (3) rescales each argument onto a single common dom scale via MulAndDiv.
 * The chosen dtype/scale are written through `dtype_ptr` / `scale_ptr`.
 */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
                            DataType* dtype_ptr, Expr* scale_ptr,
                            DataType dtype = DataType::Void()) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();
  std::vector<const QRealizeIntExprNode*> nptrs;
  Array<Expr> ret;
  for (auto arg : args) {
    const auto* nptr = arg.as<QRealizeIntExprNode>();
    ICHECK(nptr);
    nptrs.push_back(nptr);
    ret.push_back(nptr->data);
  }
  // unify the data type
  ICHECK_EQ(ref_args.size(), args.size());
  if (dtype.is_void()) {
    // Heuristic: for a binary op whose rhs is already in the input dtype,
    // stay in the input dtype; otherwise widen to the activation dtype.
    if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
      dtype = cfg->dtype_input;
    } else {
      dtype = cfg->dtype_activation;
    }
  }
  for (size_t i = 0; i < ret.size(); ++i) {
    auto ref_arg = ref_args[i].as<CallNode>();
    if (nptrs[i]->dtype != dtype) {
      ret.Set(i, Cast(ret[i], dtype));
    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
      // Freshly-quantized network inputs: round-trip through dtype_input with
      // a stop_fusion barrier so the cast is not fused away.
      auto new_arg = Cast(ret[i], cfg->dtype_input);
      new_arg = StopFusion(new_arg);
      ret.Set(i, Cast(new_arg, dtype));
    }
  }
  // unify the dom_scale
  float s = ChooseDomScale(nptrs);
  Expr dom_scale = MakeConstantScalar(DataType::Float(32), s);
  for (size_t i = 0; i < ret.size(); ++i) {
    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
  }
  *dtype_ptr = dtype;
  *scale_ptr = dom_scale;
  return ret;
}
/*!
 * \brief Realize elementwise add on quantized operands.
 *
 * Addition requires both operands on the same dtype and dom scale, so the
 * arguments are first unified (in dtype_activation) before forwarding.
 */
Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  ICHECK_EQ(new_args.size(), 2);
  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
    DataType dtype;
    Expr dom_scale;
    // execute the operation with activation data type.
    const QConfig& cfg = QConfig::Current();
    Array<Expr> ret_args =
        UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
    for (size_t i = 0; i < ret_args.size(); ++i) {
      // do not fuse float32 arg
      if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
        ret_args.Set(i, StopFusion(ret_args[i]));
      }
    }
    Expr ret = ForwardOp(ref_call, ret_args);
    return QRealizeIntExpr(ret, dom_scale, dtype);
  }
  // No rewrite: assert neither argument is a stray temp expr.
  ICHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
  return Expr(nullptr);
}
RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
/*!
 * \brief Realize clip on a quantized operand by rescaling its bounds.
 *
 * The original a_min/a_max are expressed in the real domain; dividing them by
 * the dom scale moves them onto the quantized integer grid.
 */
Expr ClipRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  ICHECK_EQ(new_args.size(), 1);
  const auto* realized = new_args[0].as<QRealizeIntExprNode>();
  if (realized == nullptr) {
    ICHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }
  const double scale = GetScalarFromConstant<float>(realized->dom_scale);
  const auto* orig_attrs = ref_call->attrs.as<ClipAttrs>();
  auto scaled_attrs = make_object<ClipAttrs>();
  scaled_attrs->a_min = orig_attrs->a_min / scale;
  scaled_attrs->a_max = orig_attrs->a_max / scale;
  Expr clipped = Call(ref_call->op, {realized->data}, Attrs(scaled_attrs), ref_call->type_args);
  return QRealizeIntExpr(clipped, realized->dom_scale, realized->dtype);
}
RELAY_REGISTER_OP("clip").set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
/*!
 * \brief Realize concatenate on quantized operands.
 *
 * All tuple fields must agree on dtype and dom scale before concatenation,
 * so they are unified first (dtype inferred from the config).
 */
Expr ConcatenateRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  ICHECK_EQ(new_args.size(), 1);
  ICHECK_EQ(ref_call->args.size(), 1);
  const auto* tuple = new_args[0].as<TupleNode>();
  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
  ICHECK(tuple);
  ICHECK(ref_tuple);
  const Array<Expr>& arr = tuple->fields;
  const Array<Expr>& ref_arr = ref_tuple->fields;
  // Checking the first field is enough: UnifyDTypeScale asserts all fields
  // are QRealizeIntExpr.
  if (arr[0].as<QRealizeIntExprNode>()) {
    DataType dtype;
    Expr dom_scale;
    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
    Expr ret = ForwardOp(ref_call, {Tuple(ret_args)});
    return QRealizeIntExpr(ret, dom_scale, dtype);
  } else {
    // No rewrite: assert no stray temp exprs remain in the tuple.
    for (auto arg : new_args) {
      ICHECK(!arg->IsInstance<TempExprNode>());
    }
    return Expr(nullptr);
  }
}
RELAY_REGISTER_OP("concatenate").set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
/* \brief Forward the original operator unchanged; dtype and dom scale pass
 * straight through. Used for ops that are linear in the quantized domain. */
Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  ICHECK_EQ(new_args.size(), 1);
  const auto* realized = new_args[0].as<QRealizeIntExprNode>();
  if (realized == nullptr) {
    ICHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }
  Expr forwarded = ForwardOp(ref_call, {realized->data});
  return QRealizeIntExpr(forwarded, realized->dom_scale, realized->dtype);
}
RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("nn.batch_flatten")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("annotation.stop_fusion")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
/* \brief For unary operators which requantize their input to the configured
 * input dtype before forwarding (e.g. max pooling). The dom scale is kept. */
Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
                           const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 1);
  const auto* realized = new_args[0].as<QRealizeIntExprNode>();
  if (realized == nullptr) {
    ICHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }
  Expr casted = Cast(realized->data, cfg->dtype_input);
  Expr forwarded = ForwardOp(ref_call, {casted});
  return QRealizeIntExpr(forwarded, realized->dom_scale, cfg->dtype_input);
}
RELAY_REGISTER_OP("nn.max_pool2d")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
RELAY_REGISTER_OP("nn.max_pool1d")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
/*!
 * \brief Realize average pooling: run the pool in the activation dtype
 * (casting up if needed); the dom scale passes through unchanged.
 */
Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 1);
  const auto* realized = new_args[0].as<QRealizeIntExprNode>();
  if (realized == nullptr) {
    ICHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }
  Expr pool_input = realized->data;
  if (realized->dtype != cfg->dtype_activation) {
    pool_input = Cast(realized->data, cfg->dtype_activation);
  }
  Expr pooled = ForwardOp(ref_call, {pool_input});
  return QRealizeIntExpr(pooled, realized->dom_scale, cfg->dtype_activation);
}
RELAY_REGISTER_OP("nn.avg_pool2d").set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
RELAY_REGISTER_OP("nn.global_avg_pool2d")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
/*!
 * \brief Realize annotation.cast_hint by materializing the hinted cast.
 * The dom scale is unchanged; only the storage dtype moves.
 */
Expr CastHintRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const auto param = ref_call->attrs.as<CastHintAttrs>();
  ICHECK_EQ(new_args.size(), 1);
  const auto* realized = new_args[0].as<QRealizeIntExprNode>();
  if (realized == nullptr) {
    ICHECK(!new_args[0]->IsInstance<TempExprNode>());
    return Expr(nullptr);
  }
  Expr casted = Cast(realized->data, param->dtype);
  return QRealizeIntExpr(casted, realized->dom_scale, param->dtype);
}
RELAY_REGISTER_OP("annotation.cast_hint")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
/*!
 * \brief Realize nn.batch_matmul on quantized operands.
 *
 * Casts lhs to the input dtype and rhs to the weight dtype, accumulates in
 * dtype_activation, and folds the product of the operand domain scales into
 * the output scale.
 */
Expr BatchMatmulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  ICHECK_EQ(new_args.size(), 2);
  // Rewrite only when both arguments are temp exprs produced by earlier
  // realize steps.
  if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  // Guard against a TempExpr that is not a QRealizeIntExpr (the unchecked
  // downcast would otherwise dereference null); matches Conv1dRealize.
  ICHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  ICHECK(rhs);
  Expr ldata = lhs->data;
  Expr rdata = rhs->data;
  DataType dtype_input = cfg->dtype_input;
  DataType dtype_weight = cfg->dtype_weight;
  if (lhs->dtype != dtype_input) {
    ldata = Cast(ldata, dtype_input);
  }
  if (rhs->dtype != dtype_weight) {
    rdata = Cast(rdata, dtype_weight);
  }
  // Copy the reference attrs but widen the accumulator to dtype_activation.
  const auto ref_attrs = ref_call->attrs.as<BatchMatmulAttrs>();
  auto attrs = make_object<BatchMatmulAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;
  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  // Output scale is the product of the operand scales, folded to a constant.
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantExpr(mul);
  return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.batch_matmul")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", BatchMatmulRealize);
/*!
 * \brief Build the function pass that applies every registered
 * "FQRealizeRewrite" rule via ForwardRewrite.
 */
Pass QuantizeRealizePass() {
  // The lambda references nothing from the enclosing scope, so an empty
  // capture list is equivalent to the by-copy default.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
      };
  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}
TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize").set_body_typed(QuantizeRealizePass);
} // namespace quantize
} // namespace relay
} // namespace tvm