# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Tests on quantized torch model conversion """
import os
import numpy as np
import torch
import tvm
import tvm.testing
from PIL import Image
from torch import nn
from torch.quantization import (
DeQuantStub,
QuantStub,
QuantWrapper,
fuse_modules,
get_default_qat_qconfig,
prepare_qat,
)
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.relay.op.contrib.register import get_pattern_table, register_pattern_table
def torch_version_check():
from packaging import version
return version.parse(torch.__version__) > version.parse("1.4.0")
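

# True on torch newer than 1.4.0; a helper for gating version-specific cases,
# alongside is_version_greater_than() from the TVM PyTorch frontend utils.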


def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False, target="llvm"):
    input_shapes = [(input_name, ishape)]
    with tvm.testing.disable_span_filling():
        mod, params = relay.frontend.from_pytorch(
            script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
        )
    with tvm.testing.enable_span_filling():
        mod_with_span, _ = relay.frontend.from_pytorch(
            script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
        )
    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)

    if keep_quantized_weight:
        for p in params.values():
            assert p.dtype in ["int8", "int32"]

    with tvm.transform.PassContext(opt_level=3):
        # Test only on CPU for now: torch cannot run quantized models on CUDA,
        # and this also keeps CI fast.
        lib = relay.build(mod, target=target, params=params)

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0)))
    return runtime


def get_qconfig(per_channel):
    from torch.quantization.observer import (
        MovingAverageMinMaxObserver,
        default_weight_observer,
    )

    if per_channel:
        return torch.quantization.get_default_qconfig("fbgemm")
    else:
        act = MovingAverageMinMaxObserver.with_args(reduce_range=False)
        return torch.quantization.QConfig(activation=act, weight=default_weight_observer)


def quantize_model(model, inp, per_channel=False):
    model.fuse_model()
    model.qconfig = get_qconfig(per_channel)
    torch.quantization.prepare(model, inplace=True)
    model(inp)
    torch.quantization.convert(model, inplace=True)
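

# The eager-mode post-training quantization flow wrapped by quantize_model():
# fuse conv/bn/relu patterns, attach a qconfig, prepare() to insert observers,
# run one calibration batch, then convert() to quantized modules.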


class ConvBn(nn.Module):
    def __init__(self, with_relu=False):
        super().__init__()
        layers = [nn.Conv2d(3, 32, 3, bias=True), nn.BatchNorm2d(32)]
        if with_relu:
            layers.append(nn.ReLU())
        self.conv = nn.Sequential(*layers)
        self.quant_wrap = QuantWrapper(self.conv)
        self.with_relu = with_relu

    def forward(self, x):
        return self.quant_wrap(x)

    def fuse_model(self):
        indices = ["0", "1"]
        if self.with_relu:
            indices.append("2")
        fuse_modules(self.conv, indices, inplace=True)


class ConvTranspose(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)]
        self.conv = nn.Sequential(*layers)
        self.quant_wrap = QuantWrapper(self.conv)

    def forward(self, x):
        return self.quant_wrap(x)

    def fuse_model(self):
        pass


class Linear(nn.Module):
    def __init__(self, with_relu=False):
        super().__init__()
        layers = [nn.Linear(16, 32)]
        if with_relu:
            layers.append(nn.ReLU())
        self.fc = nn.Sequential(*layers)
        self.quant_wrap = QuantWrapper(self.fc)
        self.with_relu = with_relu

    def forward(self, x):
        return self.quant_wrap(x)

    def fuse_model(self):
        if self.with_relu:
            fuse_modules(self.fc, ["0", "1"], inplace=True)


class ReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = QuantWrapper(nn.ReLU())

    def forward(self, x):
        return self.relu(x)

    def fuse_model(self):
        pass


class LeakyReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.leaky_relu = QuantWrapper(nn.LeakyReLU())

    def forward(self, x):
        return self.leaky_relu(x)

    def fuse_model(self):
        pass


# Mobilenet V3 related modules
class Hsigmoid(nn.Module):
    def __init__(self, add_stub=False):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.add_stub = add_stub
        self.hsigmoid = nn.Hardsigmoid()

    def forward(self, x):
        if self.add_stub:
            x = self.quant(x)
        x = self.hsigmoid(x)
        if self.add_stub:
            x = self.dequant(x)
        return x

    def fuse_model(self):
        pass


class Hswish(nn.Module):
    def __init__(self, add_stub=False):
        super().__init__()
        self.hswish = QuantWrapper(nn.Hardswish())

    def forward(self, x):
        return self.hswish(x)

    def fuse_model(self):
        pass


class SqueezeExcite(nn.Module):
    def __init__(self, channel, reduction=4, add_stub=False):
        super(SqueezeExcite, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            Hsigmoid(add_stub=False),
        )
        self.fmul = nn.quantized.FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.add_stub = add_stub

    def forward(self, x):
        b, c, _, _ = x.size()
        if self.add_stub:
            x = self.quant(x)
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        out = self.fmul.mul(x, y.expand_as(x))
        if self.add_stub:
            return self.dequant(out)
        else:
            return out

    def fuse_model(self):
        fuse_modules(self.fc, ["0", "1"], inplace=True)
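

# Note: nn.quantized.FloatFunctional is used for the elementwise multiply in
# SqueezeExcite (and below) because, in eager-mode quantization, a plain `*`
# cannot carry an observer and would not be converted to a quantized op.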


# Test quantized::mul_scalar with a negative scale
class MulScalarNegative(nn.Module):
    def __init__(self):
        super().__init__()
        self.float_op = nn.quantized.FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        mul = self.float_op.mul_scalar(x, -0.3)
        return self.dequant(mul)

    def fuse_model(self):
        pass


class UpsamplingBilinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        upsample = nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True
        )
        return self.dequant(upsample)

    def fuse_model(self):
        pass


class AvgPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = QuantWrapper(nn.AvgPool2d(kernel_size=2))

    def forward(self, x):
        return self.pool(x)

    def fuse_model(self):
        pass


class AdaptiveAvgPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1)))

    def forward(self, x):
        return self.pool(x)

    def fuse_model(self):
        pass
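

# A minimal sketch (as comments, so it is not exercised by CI) of how the
# wrapper modules above flow through the tests below:
#
#   model = ReLU().eval()
#   inp = torch.rand((1, 3, 224, 224))
#   quantize_model(model, inp)                    # eager-mode PTQ
#   script = torch.jit.trace(model, inp).eval()   # TorchScript for the frontend
#   runtime = get_tvm_runtime(script, "input", inp.shape)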


def test_quantized_modules():
    imagenet_ishape = (1, 3, 224, 224)

    qmodules = [
        ("relu", imagenet_ishape, ReLU(), False),
        ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
        ("avgpool", imagenet_ishape, AvgPool2d(), False),
    ]

    for per_channel in [False, True]:
        if per_channel:
            postfix = ", per_channel"
        else:
            postfix = ""

        qmodules += [
            ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
            ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
            ("linear" + postfix, (16, 16), Linear(), per_channel),
            ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
            ("conv_transpose", imagenet_ishape, ConvTranspose(), False),
            ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
            ("hswish", imagenet_ishape, Hswish(), False),
            ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
            ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
            ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
            ("leaky_relu", imagenet_ishape, LeakyReLU(), False),
        ]

    for (module_name, ishape, raw_module, per_channel) in qmodules:
        raw_module.eval()
        inp = torch.rand(ishape)

        # Quantized conv_transpose2d is supported only with the qnnpack engine
        # before torch v1.8.0.
        if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
            prev_engine = torch.backends.quantized.engine
            torch.backends.quantized.engine = "qnnpack"
            quantize_model(raw_module, inp, per_channel=per_channel)
            torch.backends.quantized.engine = prev_engine
        else:
            quantize_model(raw_module, inp, per_channel=per_channel)

        script_module = torch.jit.trace(raw_module, inp).eval()

        with torch.no_grad():
            pt_result = script_module(inp.clone()).numpy()

        input_name = "input"
        runtime = get_tvm_runtime(script_module, input_name, ishape)
        runtime.set_input(input_name, inp.numpy().copy())
        runtime.run()
        tvm_result = runtime.get_output(0).numpy()

        max_abs_diff = np.max(np.abs(tvm_result - pt_result))
        mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
        num_identical = np.sum(tvm_result == pt_result)
        match_ratio = num_identical / float(np.prod(tvm_result.shape))

        print(module_name, max_abs_diff, mean_abs_diff, match_ratio)

        if "linear" in module_name and tvm.get_global_func("tvm.contrib.cublas.matmul", True):
            runtime = get_tvm_runtime(script_module, input_name, ishape, target="cuda -libs=cublas")
            runtime.set_input(input_name, inp.numpy().copy())
            runtime.run()
            cublas_result = runtime.get_output(0).numpy()
            # It is generally safe to enable this assertion, but it is disabled
            # to keep CI fast:
            # tvm.testing.assert_allclose(cublas_result, pt_result, atol=1e-5, rtol=1e-5)
            print(np.max(np.abs(cublas_result - pt_result)))

    # sample outputs
    """
    relu 0.0039215684 2.6052087e-08 0.9999933567176871
    leaky_relu 0.0 0.0 1.0
    upsample bilinear 0.0 0.0 1.0
    conv_bn 0.22062653 0.011478779 0.6909348115006899
    conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
    linear 0.15987062 0.009231662 0.794921875
    linear_relu 0.14180502 0.0053220326 0.8828125
    conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806
    conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
    conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
    linear, per_channel 0.0 0.0 1.0
    linear_relu, per_channel 0.0 0.0 1.0
    hsigmoid 0.002614379 0.00020525524 0.9214896896258503
    hswish 0.0026143193 1.7367661e-08 0.9999933567176871
    hswish, per_channel 0.0 0.0 1.0
    semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
    mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
    """
    # We cannot make any guarantee on how close the raw output is to torch:
    # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)


def test_quantized_imagenet():
    def get_transform():
        import torchvision.transforms as transforms

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return transforms.Compose(
            [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
        )

    def get_real_image(im_height, im_width):
        repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
        img_name = "elephant-299.jpg"
        image_url = os.path.join(repo_base, img_name)
        img_path = download_testdata(image_url, img_name, module="data")
        return Image.open(img_path).resize((im_height, im_width))

    def get_imagenet_input():
        im = get_real_image(224, 224)
        preprocess = get_transform()
        pt_tensor = preprocess(im)
        return np.expand_dims(pt_tensor.numpy(), 0)

    from torchvision.models.quantization import googlenet as qgooglenet
    from torchvision.models.quantization import inception as qinception
    from torchvision.models.quantization import mobilenet as qmobilenet
    from torchvision.models.quantization import (
        mobilenet_v3_large as qmobilenet_v3_large,
    )
    from torchvision.models.quantization import resnet as qresnet

    per_channel = True
    qmodels = [
        ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
        ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
        ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
        # Tracing quantized googlenet is broken as of v1.6:
        # ("googlenet", qgooglenet(pretrained=True), per_channel),
        # As of v1.10, quantized mobilenet v3 has a strange segfault issue
        # during make_conv_packed_param.
        # See https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/ci-docker-staging/192
        # ("mobilenet_v3_large", qmobilenet_v3_large(pretrained=True, quantize=True).eval(), True)
    ]

    results = []

    for (model_name, raw_model, per_channel) in qmodels:
        raw_model.eval()

        if per_channel:
            model_name += ", per channel quantization"
        else:
            model_name += ", per tensor quantization"

        inp = get_imagenet_input()
        pt_inp = torch.from_numpy(inp)

        if "mobilenet_v3_large" not in model_name:
            # Mobilenet v3 was trained with QAT; the quantize=True option above
            # already returns a quantized model.
            quantize_model(raw_model, pt_inp, per_channel=per_channel)

        script_module = torch.jit.trace(raw_model, pt_inp).eval()

        with torch.no_grad():
            pt_result = script_module(pt_inp).numpy()

        input_name = "image"
        runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
        runtime.set_input(input_name, inp)
        runtime.run()
        tvm_result = runtime.get_output(0).numpy()

        results.append((model_name, pt_result[0], tvm_result[0]))

    for (model_name, pt_result, tvm_result) in results:
        max_abs_diff = np.max(np.abs(tvm_result - pt_result))
        mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
        num_identical = np.sum(tvm_result == pt_result)
        pt_top3_labels = np.argsort(pt_result)[::-1][:3]
        tvm_top3_labels = np.argsort(tvm_result)[::-1][:3]

        print("\nModel name: %s" % model_name)
        print("PyTorch top3 label:", pt_top3_labels)
        print("TVM top3 label:", tvm_top3_labels)
        print("max abs diff:", max_abs_diff)
        print("mean abs_diff:", mean_abs_diff)
        print("%d in 1000 raw outputs identical." % num_identical)

        assert set(pt_top3_labels) == set(tvm_top3_labels)

    # sample outputs
    """
    Model name: resnet18, per tensor quantization
    PyTorch top3 label: [386 101 385]
    TVM top3 label: [386 101 385]
    max abs diff: 0.65681696
    mean abs_diff: 0.14055882
    236 in 1000 raw outputs identical.

    Model name: mobilenet_v2, per tensor quantization
    PyTorch top3 label: [101 386 385]
    TVM top3 label: [101 386 385]
    max abs diff: 2.1262953
    mean abs_diff: 0.41025686
    101 in 1000 raw outputs identical.

    Model name: inception_v3, per tensor quantization
    PyTorch top3 label: [101 386 385]
    TVM top3 label: [101 386 385]
    max abs diff: 0.9994669
    mean abs_diff: 0.098697364
    272 in 1000 raw outputs identical.

    Model name: googlenet, per tensor quantization
    PyTorch top3 label: [101 386 385]
    TVM top3 label: [101 386 385]
    max abs diff: 0.28248847
    mean abs_diff: 0.0634469
    274 in 1000 raw outputs identical.

    Model name: resnet18, per channel quantization
    PyTorch top3 label: [101 386 385]
    TVM top3 label: [101 386 385]
    max abs diff: 0.65908074
    mean abs_diff: 0.1274223
    469 in 1000 raw outputs identical.

    Model name: mobilenet_v2, per channel quantization
    PyTorch top3 label: [101 386 385]
    TVM top3 label: [101 386 385]
    max abs diff: 0.71120834
    mean abs_diff: 0.15883648
    423 in 1000 raw outputs identical.

    Model name: inception_v3, per channel quantization
    PyTorch top3 label: [386 101 385]
    TVM top3 label: [386 101 385]
    max abs diff: 1.3372154
    mean abs_diff: 0.1225224
    401 in 1000 raw outputs identical.

    Model name: googlenet, per channel quantization
    PyTorch top3 label: [101 386 385]
    TVM top3 label: [101 386 385]
    max abs diff: 0.34015465
    mean abs_diff: 0.054197952
    558 in 1000 raw outputs identical.
    """


def test_serialized_modules():
    ishape = (1, 16, 64, 64)
    raw_module = AdaptiveAvgPool2d().eval()
    inp = torch.rand(ishape)

    quantize_model(raw_module, inp)
    script_module = torch.jit.trace(raw_module, inp).eval()

    fname = "tmp.pt"
    torch.jit.save(script_module, fname)
    loaded = torch.jit.load(fname)
    os.remove(fname)

    with torch.no_grad():
        pt_result = loaded(inp.clone()).numpy()

    input_name = "input"
    runtime = get_tvm_runtime(loaded, input_name, ishape)
    runtime.set_input(input_name, inp.numpy().copy())
    runtime.run()
    tvm_result = runtime.get_output(0).numpy()

    # With outputs around 0.5, an absolute tolerance of 1e-2 corresponds to a
    # relative accuracy close to 2**-6 (~0.016). For simple layers like the one
    # here, this should be achievable with 8-bit quantization; we only require
    # a 90% match just to be safe.
    num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2)
    match_ratio = num_identical / float(np.prod(tvm_result.shape))
    assert match_ratio > 0.90


def test_quantize_dynamic():
    # A wrapper is required for quantize_dynamic to work correctly.
    class LinearWrapper(nn.Module):
        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, hidden_dim)

        def forward(self, inp):
            return self.linear(inp)

    torch.manual_seed(0)
    mod = LinearWrapper(16, 32)

    for qconfig in [
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    ]:
        for ishape in [(16, 16), (10, 16, 16)]:
            qspec = {nn.Linear: qconfig}
            qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8)

            inp = torch.randn(*ishape)
            script_module = torch.jit.trace(qmod, inp).eval()

            with torch.no_grad():
                pt_result = script_module(inp.clone()).numpy()

            input_name = "input"
            runtime = get_tvm_runtime(script_module, "input", inp.shape)
            runtime.set_input(input_name, inp.numpy().copy())
            runtime.run()
            tvm_result = runtime.get_output(0).numpy()

            # Compare against the PyTorch result only for v1.6 or newer: we
            # have seen a strange accuracy problem with PyTorch 1.4 and 1.5.
            # Even with the manual random seed set, the same PyTorch version
            # can output slightly different results depending on the
            # environment. Outputs from v1.6 seem reliable, while TVM's
            # outputs are always the same.
            if is_version_greater_than("1.5.1"):
                tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)
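

# Unlike the static flow in quantize_model() above, quantize_dynamic()
# quantizes weights ahead of time and computes activation scales on the fly
# at inference time, so no calibration pass is needed.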


def make_qnn_add_pattern():
    from tvm.relay.dataflow_pattern import is_op, wildcard

    lhs = wildcard()
    rhs = wildcard()
    lhs_scale = wildcard()
    lhs_zero_point = wildcard()
    rhs_scale = wildcard()
    rhs_zero_point = wildcard()
    output_scale = wildcard()
    output_zero_point = wildcard()
    qadd = is_op("qnn.add")(
        lhs,
        rhs,
        lhs_scale,
        lhs_zero_point,
        rhs_scale,
        rhs_zero_point,
        output_scale,
        output_zero_point,
    )
    return qadd.optional(is_op("clip"))


@register_pattern_table("test_table")
def pattern_table():
    return [
        ("qnn_add", make_qnn_add_pattern()),
    ]
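

# The pattern above matches qnn.add applied to two operands plus their scales
# and zero points, optionally followed by a clip, and registers it under the
# name "qnn_add" for the MergeComposite pass used below.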


def run_qnn_mergecomposite(script_module, input_name, ishape):
    input_shapes = [(input_name, ishape)]
    with tvm.testing.disable_span_filling():
        mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.testing.enable_span_filling():
        mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_shapes)
    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)

    pattern_table = get_pattern_table("test_table")
    with tvm.transform.PassContext(opt_level=3):
        pass_list = [
            tvm.relay.transform.SimplifyInference(),
            tvm.relay.transform.MergeComposite(pattern_table),
        ]
        composite_partition = tvm.transform.Sequential(pass_list)
        # Smoke test: the partitioning pipeline must run without raising.
        partitioned = composite_partition(mod)


def test_qnn_mergecomposite():
    from torchvision.models.quantization import resnet as qresnet

    model = qresnet.resnet18(pretrained=True)
    model.eval()

    inp = torch.zeros((1, 3, 224, 224))
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    model(inp)
    torch.quantization.convert(model, inplace=True)
    script_module = torch.jit.trace(model, inp).eval()

    input_name = "image"
    run_qnn_mergecomposite(script_module, input_name, inp.shape)


def test_keep_quantized_weight():
    qmodules = []

    for per_channel in [False, True]:
        qmodules += [
            ((1, 3, 224, 224), ConvBn(), per_channel),
            ((16, 16), Linear(), per_channel),
        ]

    for (ishape, raw_module, per_channel) in qmodules:
        raw_module.eval()
        inp = torch.rand(ishape)

        quantize_model(raw_module, inp, per_channel=per_channel)
        script_module = torch.jit.trace(raw_module, inp).eval()

        input_name = "input"

        runtime = get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False)
        runtime.set_input(input_name, inp.numpy().copy())
        runtime.run()
        tvm_result = runtime.get_output(0).numpy()

        runtime_int8_weight = get_tvm_runtime(
            script_module, input_name, ishape, keep_quantized_weight=True
        )
        runtime_int8_weight.set_input(input_name, inp.numpy().copy())
        runtime_int8_weight.run()
        tvm_result_int8_weight = runtime_int8_weight.get_output(0).numpy()

        tvm.testing.assert_allclose(tvm_result, tvm_result_int8_weight)
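

# keep_quantized_weight=True makes the frontend keep int8/int32 parameters
# instead of dequantizing weights back to float at conversion time (see the
# dtype assertion in get_tvm_runtime); both paths must give identical outputs.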


def test_tuple_lowered():
    # See the following discuss thread for details:
    # https://discuss.tvm.apache.org/t/bug-frontend-pytorch-relay-ir-is-inconsistent-with-that-of-the-original-model/12010

    class ConvBnRelu(nn.Module):
        def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1, bias=True, groups=1):
            super(ConvBnRelu, self).__init__()
            if groups > 1:
                self.conv = nn.Conv2d(
                    inp, inp, kernel_size, stride, padding, bias=bias, groups=groups
                )
                self.bn = nn.BatchNorm2d(inp)
            else:
                self.conv = nn.Conv2d(
                    inp, oup, kernel_size, stride, padding, bias=bias, groups=groups
                )
                self.bn = nn.BatchNorm2d(oup)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, inputs):
            x = self.conv(inputs)
            x = self.bn(x)
            x = self.relu(x)
            return x

    def conv_bn(inp, oup, stride=1, width_multiplier=1):
        return ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)

    def conv_dw(inp, oup, stride, width_multiplier=1, padding=1):
        dw_block = nn.Sequential()
        depth_wise = ConvBnRelu(
            inp, oup, kernel_size=3, stride=stride, padding=padding, bias=False, groups=inp
        )
        point_wise = ConvBnRelu(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
        dw_block.add_module("depth_wise", depth_wise)
        dw_block.add_module("point_wise", point_wise)
        return dw_block

    class Backbone(nn.Module):
        def __init__(self, width_multiplier=1):
            super(Backbone, self).__init__()
            self.width_multiplier = width_multiplier
            self.conv1 = conv_bn(3, 16, 2, self.width_multiplier)
            self.conv2 = conv_dw(16, 32, 1, self.width_multiplier)

        def forward(self, inputs):
            x1 = self.conv1(inputs)
            x2 = self.conv2(x1)
            return [x1, x2]

    class QuantizableBackbone(nn.Module):
        def __init__(self, inputsize=(128, 128)):
            super(QuantizableBackbone, self).__init__()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
            self.backbone = Backbone()

        def fuse_model(self):
            fuse_modules_qat = getattr(torch.ao.quantization, "fuse_modules_qat", fuse_modules)
            for idx, m in enumerate(self.modules()):
                if type(m) == ConvBnRelu:
                    fuse_modules_qat(m, ["conv", "bn", "relu"], inplace=True)

        def forward(self, input):
            input = self.quant(input)
            y0, y1 = self.backbone(input)
            y0 = self.dequant(y0)
            y1 = self.dequant(y1)
            return y0, y1

    fp32_input = torch.randn(1, 3, 128, 128)
    model = QuantizableBackbone()
    model.train()
    model.fuse_model()
    model.qconfig = get_default_qat_qconfig("qnnpack")

    prepare_qat(model, inplace=True)

    model.eval()
    model(fp32_input)

    model_int8 = torch.quantization.convert(model, inplace=True)
    script_module = torch.jit.trace(model_int8, fp32_input).eval()

    input_infos = [("input", (fp32_input.shape, "float32"))]
    with tvm.testing.disable_span_filling():
        mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
    with tvm.testing.enable_span_filling():
        mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_infos)
    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)

    output = mod["main"].body
    assert isinstance(output, relay.Tuple) and len(output) == 2
    dq1, dq2 = output
    assert dq1.op.name == "qnn.dequantize" and dq2.op.name == "qnn.dequantize"
    scale1 = dq1.args[1].data.numpy().item()
    scale2 = dq2.args[1].data.numpy().item()
    assert scale1 != scale2
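

if __name__ == "__main__":
    # Entry point so this file can be run directly. This assumes the standard
    # TVM test harness is available; tvm.testing.main() runs pytest on the
    # current file in recent TVM versions.
    tvm.testing.main()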