# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
from collections import namedtuple

import mxnet as mx
from mxnet import gluon

import tvm
import tvm.testing
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import quantize as qtz

logging.basicConfig(level=logging.INFO)

Config = namedtuple(
    "Config",
    [
        "model",
        "nbit_input",
        "dtype_input",
        "nbit_output",
        "dtype_output",
        "global_scale",
        "expected_acc",
    ],
)


def get_val_data(model_name, rec_val, batch_size, num_workers=4):
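    """Return an ``mx.io.ImageRecordIter`` over the ImageNet validation record
    file plus a ``batch_fn`` that splits a batch's data and labels across the
    given contexts. Images are resized to 256, cropped to the model's input
    size, and normalized with the standard ImageNet mean/std."""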
    rec_val = os.path.expanduser(rec_val)
    # standard ImageNet per-channel mean/std (RGB) used for input normalization
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]

    def batch_fn(batch, ctx):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        return data, label

    img_size = 299 if model_name == "inceptionv3" else 224
    val_data = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn


def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
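    """Load a pretrained Gluon model, convert it to Relay, and apply the
    quantization prerequisite optimizations. Returns the float32 function when
    ``original`` is True, otherwise the function quantized under ``qconfig``."""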
    gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
    img_size = 299 if model_name == "inceptionv3" else 224
    data_shape = (batch_size, 3, img_size, img_size)
    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
    net = mod["main"]

    with tvm.transform.PassContext(opt_level=3):
        qfunc = relay.quantize.prerequisite_optimize(net, params=params)
    logging.debug("original")
    logging.debug(qfunc.astext(show_meta_data=False))
    if original:
        return qfunc

    with qconfig:
        logging.debug("current quantize config")
        logging.debug(qtz.current_qconfig())
        qfunc = qtz.quantize(qfunc)
        logging.debug("after quantize")
        logging.debug(qfunc.astext(show_meta_data=False))
    return qfunc


def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100):
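    """Compile ``model`` for ``target``, run it over the validation set, and
    return top-1 accuracy; top-1/top-5 are logged every ``log_interval`` batches."""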
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = relay.build(model, target)
    # create the graph runtime module and load the compiled parameters
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)

    # setup evaluation metrics
    dataset.reset()
    batch_size = dataset.batch_size
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1.reset()
    acc_top5.reset()

    # execute over the validation set
    for i, batch in enumerate(dataset):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.run(data=data[0].asnumpy())
        out_arr = m.get_output(0)
        acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())])
        acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())])

        if not (i + 1) % log_interval:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            nsamples = (i + 1) * batch_size
            logging.info("[%d samples] validation: acc-top1=%f acc-top5=%f", nsamples, top1, top5)

    # report accuracy over the full validation set, not just the last log interval
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info("[final] validation: acc-top1=%f acc-top5=%f", top1, top5)
    return top1


@tvm.testing.requires_gpu
def test_quantize_acc(cfg, rec_val):
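    """Quantize ``cfg.model`` with the bit widths and dtypes from ``cfg`` and
    check that ImageNet top-1 accuracy stays above ``cfg.expected_acc``."""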
    qconfig = qtz.qconfig(
        skip_conv_layers=[0],
        nbit_input=cfg.nbit_input,
        nbit_weight=cfg.nbit_input,
        global_scale=cfg.global_scale,
        dtype_input=cfg.dtype_input,
        dtype_weight=cfg.dtype_input,
        dtype_activation=cfg.dtype_output,
        debug_enabled_ops=None,
    )
    model = get_model(cfg.model, 32, qconfig, tvm.target.cuda())
    val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32)

    acc = eval_acc(model, val_data, batch_fn)
    assert acc > cfg.expected_acc
    return acc


if __name__ == "__main__":
    # TODO(for user): replace this path with your ImageNet validation record
    # file (e.g., built with MXNet's im2rec tool).
    rec_val = "/scratch/tqchen/imagenet/val.rec"

    results = []
    configs = [
        Config(
            "mobilenetv2_1.0",
            nbit_input=8,
            dtype_input="int8",
            nbit_output=32,
            dtype_output="int32",
            global_scale=4.0,
            expected_acc=0.666,
        ),
        Config(
            "resnet18_v1",
            nbit_input=8,
            dtype_input="int8",
            nbit_output=16,
            dtype_output="int16",
            global_scale=8.0,
            expected_acc=0.692,
        ),
        Config(
            "resnet18_v1",
            nbit_input=8,
            dtype_input="int8",
            nbit_output=32,
            dtype_output="int32",
            global_scale=8.0,
            expected_acc=0.692,
        ),
        Config(
            "resnet34_v1",
            nbit_input=8,
            dtype_input="int8",
            nbit_output=32,
            dtype_output="int32",
            global_scale=8.0,
            expected_acc=0.733,
        ),
        Config(
            "resnet50_v1",
            nbit_input=8,
            dtype_input="int8",
            nbit_output=32,
            dtype_output="int32",
            global_scale=8.0,
            expected_acc=0.747,
        ),
        Config(
            "resnet101_v1",
            nbit_input=8,
            dtype_input="int8",
            nbit_output=32,
            dtype_output="int32",
            global_scale=8.0,
            expected_acc=0.756,
        ),
        # TODO: need to fix accuracy
        # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
    ]

    # quantize and evaluate each configuration, then print (config, accuracy) pairs
    for config in configs:
        acc = test_quantize_acc(config, rec_val)
        results.append((config, acc))

    for res in results:
        print(res)