/*!
* Copyright (c) 2018 by Contributors
*
* \file quantize.cc
*
* \brief transform a graph to a low-bit graph
* for compression and acceleration.
*/
#include <dmlc/thread_local.h>
#include <tvm/base.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <cmath>
#include <string>
#include <vector>
#include <stack>
#include "pattern_util.h"
#include "quantize.h"
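
// This file implements the C++ pieces of the Relay quantization pipeline:
//   * the `simulated_quantize` annotation op and its type relation,
//   * the annotate pass, which marks where quantization should happen,
//   * the realize pass, which rewrites the simulated graph into actual
//     low-bit operations, and
//   * QConfig, a thread-local, scoped configuration object.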
namespace tvm {
namespace relay {
namespace quantize {
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
bool sign;
std::string rounding;
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(kind)
    .describe("Kind of the annotated value; a hint for the nbit/dtype configuration.");
TVM_ATTR_FIELD(sign).set_default(true)
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
}
};
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
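
// Type relation for simulated_quantize: types = {data, dom_scale, clip_min,
// clip_max, output}. The three scale/clip inputs are float32 scalars and the
// output type matches the input. Conceptually (a sketch of the intended
// semantics; the actual lowering lives outside this file):
//   out = clip(round(data / dom_scale), clip_min, clip_max) * dom_scale
// i.e. the op simulates integer rounding and saturation while staying in
// float32.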
bool SimulatedQuantizeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
const auto param = attrs.as<SimulatedQuantizeAttrs>();
CHECK(param != nullptr);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
reporter->Assign(types[1], TensorTypeNode::make({}, Float(32))); // dom_scale
reporter->Assign(types[2], TensorTypeNode::make({}, Float(32))); // clip_min
reporter->Assign(types[3], TensorTypeNode::make({}, Float(32))); // clip_max
reporter->Assign(types[4], types[0]); // output
return true;
}
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.describe(R"code(Simulated quantize: mimic the rounding and clipping of integer
quantization while keeping the computation in float32, so the network can be
calibrated before it is realized into low-bit operations.
)code" TVM_ADD_FILELINE)
.set_num_inputs(4)
.add_argument("data", "Tensor", "The input data.")
.add_argument("dom_scale", "Tensor", "The domain scale of the input data. It should be a scalar.")
.add_argument("clip_min", "Tensor", "The lower clip bound. It should be a scalar.")
.add_argument("clip_max", "Tensor", "The upper clip bound. It should be a scalar.")
.set_attrs_type_key("relay.attrs.SimulatedQuantizeAttrs")
.set_support_level(10)
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_API("relay._quantize.simulated_quantize")
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>(
[](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
int kind, bool sign, std::string rounding) {
auto attrs = make_node<SimulatedQuantizeAttrs>();
attrs->kind = kind;
attrs->sign = sign;
attrs->rounding = rounding;
static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
});
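
// Note: the passes below assume dom_scale, clip_min and clip_max arrive as
// scalar constants (typically chosen during calibration); only `data` is a
// real dataflow input, and `kind` merely hints which nbit/dtype
// configuration (input, weight or activation) applies.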
// =============
// annotate pass
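//
// The annotate pass tags intermediate values with QAnnotateExpr, a TempExpr
// that records which role (QAnnotateKind) a value plays. Realize() is called
// when an annotated value escapes into an expression that was not rewritten;
// values referenced more than once are handled by the fmulti_ref trigger
// passed to ForwardRewrite below.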
Expr QAnnotateExprNode::Realize() const {
const auto& cfg = QConfig::Current();
if (cfg->store_lowbit_output) {
// store low bit output back for VTA
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
return (*f)(this->expr, static_cast<int>(kQInput));
} else {
return expr;
}
}
QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
auto rnode = make_node<QAnnotateExprNode>();
rnode->expr = expr;
rnode->kind = kind;
return QAnnotateExpr(rnode);
}
TVM_REGISTER_API("relay._quantize.make_annotate_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QAnnotateExprNode::make(args[0],
static_cast<QAnnotateKind>(args[1].operator int()));
});
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
// =============
// realize pass
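//
// The realize pass rewrites the annotated, simulated graph into actual
// low-bit operations. A QRealizeIntExpr temporarily carries
// (data, dom_scale, dtype) with the invariant
//   real_value ~= data * dom_scale
// so Realize() can always fall back to float32 by casting the data and
// multiplying the scale back in.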
Expr QRealizeIntExprNode::Realize() const {
const auto& cfg = QConfig::Current();
Expr data = this->data;
if (cfg->store_lowbit_output) {
data = Cast(data, cfg->dtype_input);
}
// dequantize
data = Cast(data, Float(32));
data = Multiply(data, this->dom_scale);
return data;
}
QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>();
n->data = std::move(data);
n->dom_scale = std::move(dom_scale);
n->dtype = std::move(dtype);
return QRealizeIntExpr(n);
}
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
return CallNode::make(ref_call->op,
args, ref_call->attrs, ref_call->type_args);
}
/*! \brief Calculate `data * s1 / s2`; use a shift if possible. */
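// Worked example with illustrative numbers: s1 = 4.0, s2 = 1.0 gives
// factor = 4 and shift_factor = log2(4) = 2, so this emits `data << 2`.
// With s1 = 3.0 the factor is integral but not a power of two, so an
// integer multiply by 3 is emitted instead.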
inline Expr MulAndDiv(Expr data, float s1, float s2) {
  // Here we assume `data` already has the activation dtype.
const QConfig& cfg = QConfig::Current();
if (s1 == s2) return data;
float factor = s1 / s2;
float shift_factor = std::log2(factor);
CHECK_GT(shift_factor, 0);
if (static_cast<int>(shift_factor) == shift_factor) {
return LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_factor)));
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor));
  } else {
    // Non-integral factor: there is no lossless integer rewrite. LOG(FATAL)
    // aborts here; the cast/multiply below is dead code that only satisfies
    // the compiler's return-path analysis.
    LOG(FATAL) << "fall back to float computation";
    data = Cast(data, Float(32));
    return Multiply(data, MakeConstantScalar(Float(32), factor));
  }
}
Expr QuantizeRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
  // Only the default "round" rounding mode is handled when realizing.
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  CHECK_EQ(param->rounding, "round");
Expr dom_scale = new_args[1];
Expr clip_min = new_args[2];
Expr clip_max = new_args[3];
float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
float clip_min_imm = GetScalarFromConstant<float>(clip_min);
float clip_max_imm = GetScalarFromConstant<float>(clip_max);
// x * idom_scale = y * odom_scale
// => y = x * idom_scale / odom_scale
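  // Worked example with illustrative numbers: idom_scale = 1/16 and
  // odom_scale = 1/4 give shift_nbit = log2((1/4)/(1/16)) = 2. With
  // round_for_shift enabled a bias of 2^(2-1) = 2 is added first, so
  // data = 23 becomes (23 + 2) >> 2 = 6 = round(23 / 4) before clipping.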
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
// int32->int8
Expr data = n->data;
float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
if (idom_scale_imm == odom_scale_imm) {
// same domain scale, only clip
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
}
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
CHECK_GT(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) {
// use right shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
// float computation
data = Cast(data, Float(32));
Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
}
}
// quantize from real
CHECK(!new_args[0]->derived_from<TempExprNode>());
Expr data = new_args[0];
Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
}
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
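
// Conv2dRealize relies on simple scale algebra: if lhs ~= a * s_a and
// rhs ~= b * s_b, then conv(lhs, rhs) ~= conv(a, b) * (s_a * s_b). The
// convolution therefore runs on the integer tensors while the two dom_scales
// are just multiplied (and constant-folded) to form the output scale.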
Expr Conv2dRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
return Expr(nullptr);
}
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
CHECK(lhs);
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
CHECK(rhs);
Expr ldata = lhs->data;
if (lhs->dtype != cfg->dtype_input) {
ldata = Cast(ldata, cfg->dtype_input);
}
Expr rdata = Cast(rhs->data, cfg->dtype_weight);
const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
auto attrs = make_node<Conv2DAttrs>();
*attrs = *ref_attrs;
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
Expr ret = CallNode::make(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
Expr MulRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
// execute the operation with activation data type.
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
Expr ldata = lhs->data;
Expr rdata = rhs->data;
DataType dtype = cfg->dtype_activation;
if (lhs->dtype == Float(32)) {
ldata = Cast(ldata, dtype);
} else {
CHECK_EQ(lhs->dtype, dtype);
}
if (rhs->dtype == Float(32)) {
rdata = Cast(rdata, dtype);
} else {
CHECK_EQ(rhs->dtype, dtype);
}
Expr ret = ForwardOp(ref_call, {ldata, rdata});
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
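
// For a binary op the smaller (finer) dom_scale is chosen, so the other
// operand can be aligned by a lossless left shift in MulAndDiv. Illustrative
// numbers: s1 = 1/8, s2 = 1/16 picks 1/16, and the s1 operand is shifted
// left by log2((1/8)/(1/16)) = 1 bit. Otherwise a scale is derived from the
// configured global_scale and the activation bit width.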
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
if (nptrs.size() == 2) {
// x = a * s1, y = b * s2
// x + y = (a * s1 / s2 + b) * s2, if s1 > s2
// = (a + b * s2 / s1) * s1, if s2 > s1
float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
return s1 > s2 ? s2 : s1;
} else {
const QConfig& cfg = QConfig::Current();
float scale = cfg->global_scale;
return scale / std::pow(2.0, cfg->nbit_activation - 1);
}
}
/*! \brief Unify the dom_scale of the arguments. */
Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
const QConfig& cfg = QConfig::Current();
std::vector<const QRealizeIntExprNode*> nptrs;
Array<Expr> ret;
for (auto arg : args) {
const auto* nptr = arg.as<QRealizeIntExprNode>();
CHECK(nptr);
nptrs.push_back(nptr);
ret.push_back(nptr->data);
}
// unify the data type
DataType dtype = cfg->dtype_activation;
for (size_t i = 0; i < ret.size(); ++i) {
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
}
}
// unify the dom_scale
float s = ChooseDomScale(nptrs);
Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
ret.Set(i, MulAndDiv(ret[i], cur_s, s));
}
*dtype_ptr = dtype;
*scale_ptr = dom_scale;
return ret;
}
Expr AddRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
Expr ConcatenateRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 1);
const auto* tuple = new_args[0].as<TupleNode>();
CHECK(tuple);
const Array<Expr>& arr = tuple->fields;
if (arr[0].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
} else {
for (auto arg : new_args) {
CHECK(!arg->derived_from<TempExprNode>());
}
return Expr(nullptr);
}
}
RELAY_REGISTER_OP("concatenate")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
/*! \brief Forward the original operator. */
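// The scale passes through untouched because these ops commute with a
// positive per-tensor scale: for s > 0, relu(x * s) = relu(x) * s, and
// strided_slice picks elements without changing their values.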
Expr IdentityRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
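
// max_pool2d also commutes with a positive scale (the max of scaled values
// is the scaled max), so dom_scale is forwarded unchanged; the data is only
// narrowed to the configured input dtype before pooling.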
Expr MaxPoolRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr data = Cast(n->data, cfg->dtype_input);
Expr ret = ForwardOp(ref_call, {data});
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MaxPoolRealize);
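
// avg_pool2d sums many elements before dividing, so the data is widened to
// the activation dtype to leave headroom for the accumulation; the scale is
// forwarded unchanged because averaging is linear in its input.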
Expr AvgPoolRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr data = n->data;
if (n->dtype != cfg->dtype_activation) {
data = Cast(n->data, cfg->dtype_activation);
}
Expr ret = ForwardOp(ref_call, {data});
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// =============
// qconfig
QConfig qconfig() {
return QConfig(make_node<QConfigNode>());
}
/*! \brief Entry to hold the QConfig context stack. */
struct TVMQConfigThreadLocalEntry {
  /*! \brief The default qconfig if the stack is empty */
  QConfig default_config;
  /*! \brief The current qconfig context stack */
  std::stack<QConfig> context_stack;
TVMQConfigThreadLocalEntry() :
default_config(qconfig()) {
}
};
/*! \brief Thread local store to hold the QConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry> TVMQConfigThreadLocalStore;
void QConfig::EnterQConfigScope(const QConfig& config) {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.push(config);
}
void QConfig::ExitQConfigScope() {
TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
entry->context_stack.pop();
}
QConfig& QConfig::Current() {
TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
return entry->default_config;
}
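
// A minimal usage sketch (normally the Python bindings below drive this
// through _EnterQConfigScope/_ExitQConfigScope):
//
//   QConfig cfg = qconfig();
//   QConfig::EnterQConfigScope(cfg);
//   // ... passes observe cfg via QConfig::Current() ...
//   QConfig::ExitQConfigScope();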
TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<QConfigNode>([](const QConfigNode *op, IRPrinter *p) {
p->stream << "qconfig(";
p->stream << "nbit_input=" << op->nbit_input << ", ";
p->stream << "nbit_weight=" << op->nbit_weight << ", ";
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_k_conv=" << op->skip_k_conv << ", ";
p->stream << "round_for_shift=" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output=" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops=" << op->debug_enabled_ops;
p->stream << ")";
});
TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = QConfig::Current();
});
TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
QConfig target = args[0];
QConfig::EnterQConfigScope(target);
});
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
QConfig::ExitQConfigScope();
});
} // namespace quantize
} // namespace relay
} // namespace tvm