blob: 4b2d55ebe86491059c80eec46d784b01e59de69a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find scales for quantization on the dataset."""
from __future__ import absolute_import
import logging
import multiprocessing as mp
import numpy as np
import tvm
import tvm.driver
from tvm.ir import IRModule
from . import _quantize
from . import quantize
from .. import op as _op
from .. import expr as _expr
from .. import analysis as _analysis
from .. import build_module as _build_module
from ...contrib import graph_executor
from .kl_divergence import _find_scale_by_kl
def _get_profile_runtime(mod):
func = mod["main"]
func = _quantize.CreateStatsCollector(func)
if tvm.target.Target.current():
target = tvm.target.Target.current()
dev = tvm.device(target.kind.name)
else:
target = "llvm"
dev = tvm.device(target)
with tvm.transform.PassContext(opt_level=3):
lib = _build_module.build(func, target=target)
runtime = graph_executor.GraphModule(lib["default"](dev))
return runtime
def collect_stats(mod, dataset, chunk_by=-1):
"""Given an annotated graph, create a profile graph to collect profile data from the
calibration dataset. This pass collects simulated_quantize op input into a tuple.
Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
graph.
Parameters
----------
mod: Module
The simulation graph after annotation.
dataset: Iterable[NDArray]
The calibration dataset.
chunk_by: optional, int
The size of chunk to be returned in one iteration. It is meant to be
used for reducing memory usage. If not specified, return samples for
all layers in one chunk.
Returns
-------
ret: Iterable[list of ndarray]
List of output data of each layer, chunked by the chunk_by parameter
"""
logging.info("collecting statistics for calibration...")
runtime = _get_profile_runtime(mod)
num_outputs = runtime.get_num_outputs()
chunk_by = num_outputs if chunk_by == -1 else chunk_by
for i in range(0, num_outputs, chunk_by):
outputs = [[] for i in range(min(chunk_by, num_outputs - i))]
for batch in dataset:
runtime.set_input(**batch)
runtime.run()
for j in range(i, min(i + chunk_by, num_outputs)):
outputs[j - i].append(runtime.get_output(j).numpy())
yield [np.concatenate(output).reshape(-1) for output in outputs]
def _kl_scale(mod, dataset):
cfg = quantize.current_qconfig()
chunk_by = cfg.calibrate_chunk_by
scales = []
for samples in collect_stats(mod, dataset, chunk_by):
logging.info("finding threshold with kl for calibration...")
with mp.Pool() as pool:
scales += list(pool.map(_find_scale_by_kl, samples))
def func(_):
scale = scales[func.scale_idx]
func.scale_idx += 1
return scale
func.scale_idx = 0
return func
def _find_scale_by_percentile(arr, percentile=0.99999):
assert isinstance(arr, np.ndarray)
x = np.abs(arr)
max_k = int(x.size * percentile)
return np.partition(x, max_k)[max_k]
def _percentile_scale(mod, dataset):
cfg = quantize.current_qconfig()
chunk_by = cfg.calibrate_chunk_by
scales = []
for samples in collect_stats(mod, dataset, chunk_by):
logging.info("finding threshold with percentile for calibration...")
with mp.Pool() as pool:
scales += list(pool.map(_find_scale_by_percentile, samples))
def func(_):
scale = scales[func.scale_idx]
func.scale_idx += 1
return scale
func.scale_idx = 0
return func
def _set_params(mod, input_scale_func, weight_scale_func):
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
cfg = quantize.current_qconfig()
const_params = {}
def visit_func(expr):
"""visitor function for traverse"""
if isinstance(expr, _expr.Call) and expr.op == quantize_op:
_, ndom_scale, nclip_min, nclip_max = expr.args
attrs = expr.attrs
kind = attrs.kind
nbit = cfg.get_nbit_by_kind(kind)
valid_bit = nbit - attrs.sign
# set scale
if kind == quantize.QAnnotateKind.WEIGHT:
assert isinstance(expr.args[0], _expr.Constant)
scale = weight_scale_func(expr)
else:
scale = input_scale_func(expr)
def _make_const(val):
return _expr.const(val, "float32")
valid_range = 2**valid_bit
const_params[ndom_scale] = _make_const(scale / valid_range)
const_params[nclip_min] = _make_const(-(valid_range - 1))
const_params[nclip_max] = _make_const((valid_range - 1))
main_func = mod["main"]
_analysis.post_order_visit(main_func, visit_func)
main_func = _expr.bind(main_func, const_params)
func_dict = {}
for global_var, func in mod.functions.items():
if global_var.name_hint != "main":
func_dict[global_var] = func
return IRModule.from_expr(main_func, func_dict)
# weight scale functions
def _power2_scale(sq_call): # pylint: disable=unused-argument
"""calculate weight scale with nearest mode-2 scale"""
var = sq_call.args[0]
assert isinstance(var, _expr.Constant)
val = np.amax(np.abs(var.data.numpy()))
return 2 ** np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
def _max_scale(sq_call):
"""calculate weight scale with maximum absolute value"""
var = sq_call.args[0]
assert isinstance(var, _expr.Constant)
val = np.amax(np.abs(var.data.numpy()))
return val
# input scale functions
def _global_scale(sq_call): # pylint: disable=unused-argument
cfg = quantize.current_qconfig()
return cfg.global_scale
def calibrate(dataset=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
Parameters
---------
dataset: Optional[Iterable[NDArray]]
The calibration dataset.
Returns
-------
ret: Function
The module pass function.
"""
def wrapped_func(mod, _):
"""make transform.module pass happy"""
cfg = quantize.current_qconfig()
if cfg.calibrate_mode == "kl_divergence":
input_scale_func = _kl_scale(mod, dataset)
elif cfg.calibrate_mode == "global_scale":
input_scale_func = _global_scale
elif cfg.calibrate_mode == "percentile":
input_scale_func = _percentile_scale(mod, dataset)
else:
raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))
if cfg.weight_scale == "max":
weight_scale_func = _max_scale
elif cfg.weight_scale == "power2":
weight_scale_func = _power2_scale
else:
raise ValueError("Unknown weight scale mode {}".format(cfg.weight_scale))
return _set_params(mod, input_scale_func, weight_scale_func)
return wrapped_func