# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Functions for enabling AMP (automatic mixed precision)."""
__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model',
'convert_hybrid_block', 'list_lp16_ops', 'list_fp32_ops',
'list_lp16_fp32_ops', 'list_conditional_fp32_ops',
'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
'convert_symbol']
from array import array
import ctypes
import inspect
import logging
import contextlib
import sys
import numpy as np
from mxnet import numpy
from .. import symbol
from ..device import gpu
from ..symbol import Symbol
from ..symbol import contrib as symbol_contrib
from .. import ndarray
from ..ndarray import NDArray, dtype_np_to_mx, get_dtype_type, get_dtype_name, bfloat16
from . import lists
from ..gluon import Block, HybridBlock, trainer
from .. import base
from ..base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
_NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
c_str_array, c_str, c_array_buf, SymbolHandle, check_call, _LIB)
from .. import optimizer as opt
from .loss_scaler import LossScaler
from ..operator import get_all_registered_operators_grouped
from ..util import wrap_ctx_to_device_func
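# Symbol attribute used by convert_symbol()/convert_model() to mark parameters
# that should be cast offline to the AMP target dtype.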
OFFLINE_CAST_DTYPE_ATTR = '__amp_dtype__'
float_types_gpu = (np.float16, np.float32)
float_types_cpu = (bfloat16, np.float32)
def _cast_symbol_NDArray(s, dtype, is_numpy_module=False):
if isinstance(s, Symbol):
amp_cast = symbol.numpy._internal.amp_cast if is_numpy_module else symbol.amp_cast
return amp_cast(s, dtype=dtype)
if isinstance(s, NDArray):
amp_cast = ndarray.numpy._internal.amp_cast if is_numpy_module else ndarray.amp_cast
if s.dtype != dtype and (s.dtype in float_types_gpu and s.context.device_type != 'cpu' or
s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
return amp_cast(s, dtype=dtype)
return s
def _get_nd_fun_to_wrap(name, module, submodule_dict):
module_internal = getattr(module, "_internal")
prefix = base._get_op_name_prefix(name)
if prefix:
if prefix != '_random_' or name.endswith('_like'):
func_name = name[len(prefix):]
cur_module = submodule_dict[prefix]
else:
func_name = name
cur_module = module_internal
elif name.startswith('_'):
func_name = name
cur_module = module_internal
else:
func_name = name
cur_module = module
return func_name, [cur_module]
def _get_np_fun_to_wrap(name, ns_prefix):
for pre, mod, subs in ((_NP_OP_PREFIX, 'numpy', _NP_OP_SUBMODULE_LIST),
(_NP_EXT_OP_PREFIX, 'numpy_extension', _NP_EXT_OP_SUBMODULE_LIST),
(_NP_INTERNAL_OP_PREFIX, 'numpy._internal', [])):
if name.startswith(pre):
nm = name[len(pre):]
for sub in subs:
if nm.startswith(sub):
func, modules = nm[len(sub):], [sys.modules[f'{ns_prefix}.{mod}.{sub[1:-1]}']]
break
else:
func, modules = nm, [sys.modules[f'{ns_prefix}.{mod}']]
break
else:
assert False, f'Unable to find target module for {name} in {ns_prefix}'
if name.startswith(_NP_INTERNAL_OP_PREFIX) and ns_prefix == 'mxnet.ndarray':
if hasattr(ndarray.numpy._api_internal, func):
modules.append(ndarray.numpy._api_internal)
return func, modules
def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases, get_cond_aliases,
get_fun_to_wrap, target_precision_ops=None, conditional_fp32_ops=None,
fp32_ops=None):
nd_mod = ndarray.numpy._internal if is_numpy_module else ndarray
sy_mod = symbol.numpy._internal if is_numpy_module else symbol
def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
def _new_fun(*args, **kwargs):
if cond_arg is not None:
if (cond_arg[0] not in kwargs or
kwargs[cond_arg[0]] not in cond_arg[1]):
return f(*args, **kwargs)
if fp32_param:
new_args = []
for i, x in enumerate(args):
if fp32_param[i]:
new_args.append(x)
else:
new_args.append(_cast_symbol_NDArray(x, target_dtype, is_numpy_module))
else:
new_args = list(map(
lambda x: _cast_symbol_NDArray(x, target_dtype, is_numpy_module), args))
args = tuple(new_args)
if fp32_param:
new_kwargs = {}
for k, v in kwargs.items():
if k in fp32_param:
new_kwargs[k] = v
else:
new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype, is_numpy_module)
kwargs = new_kwargs
else:
kwargs = {k: _cast_symbol_NDArray(v, target_dtype, is_numpy_module)
for k, v in kwargs.items()}
return f(*args, **kwargs)
_new_fun.__name__ = f.__name__
_new_fun.__module__ = f.__module__
_new_fun.__doc__ = f.__doc__
return _new_fun
def _symbol_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
def _new_fun(*args, **kwargs):
if cond_arg is not None:
if (cond_arg[0] not in kwargs or
kwargs[cond_arg[0]] not in cond_arg[1]):
return f(*args, **kwargs)
sym = f(*args, **kwargs)
inputs = sym.get_children()
aux = sym.list_auxiliary_states()
if fp32_param:
new_inputs = []
for i, x in enumerate(inputs):
if (x.name in aux) or fp32_param[i]:
new_inputs.append(x)
else:
new_inputs.append(_cast_symbol_NDArray(x, target_dtype, is_numpy_module))
inputs = new_inputs
else:
inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype, is_numpy_module)
if x.name not in aux else x, inputs))
atomic_sym = sym._gen_atomic_symbol()
wrapped_sym = atomic_sym(*inputs)
wrapped_sym._set_attr(name=sym.name)
return wrapped_sym
_new_fun.__name__ = f.__name__
_new_fun.__module__ = f.__module__
_new_fun.__doc__ = f.__doc__
return _new_fun
def _symbol_widest_wrapper(f):
def _new_fun(*args, **kwargs):
symbols = []
is_symbol = False
args = list(args)
for i, arg in enumerate(args):
if isinstance(arg, (Symbol, NDArray)):
symbols.append((args, i, arg))
is_symbol = is_symbol or isinstance(arg, Symbol)
for k, arg in kwargs.items():
if isinstance(arg, (Symbol, NDArray)):
symbols.append((kwargs, k, arg))
is_symbol = is_symbol or isinstance(arg, Symbol)
if not is_symbol:
# NDArray case
widest_type = target_dtype
for _, _, arg in symbols:
if isinstance(arg, NDArray):
if arg.dtype == np.float32:
widest_type = np.float32
for arr, index, arg in symbols:
if arg.dtype != widest_type and arg.dtype == target_dtype:
arr[index] = nd_mod.amp_cast(arg, dtype=widest_type)
else:
# Symbol case
sym_to_check = list(map(lambda x: x[2], symbols))
casted_syms = sy_mod.amp_multicast(*sym_to_check, num_outputs=len(sym_to_check))
symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
zip(symbols, casted_syms)))
for arr, index, arg in symbols:
arr[index] = arg
return f(*args, **kwargs)
_new_fun.__name__ = f.__name__
_new_fun.__module__ = f.__module__
_new_fun.__doc__ = f.__doc__
return _new_fun
_wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib) else _ndarray_wrapper
fp32_param_list = list_lp16_use_fp32_params(target_dtype)
wrap_list = target_precision_ops if target_precision_ops is not None \
else list_lp16_ops(target_dtype)
for fun_name in get_aliases(wrap_list):
fun_name, modules = get_fun_to_wrap(fun_name, module)
for cur_module in modules:
f_to_wrap = getattr(cur_module, fun_name)
fp32_param = fp32_param_list[fun_name] if (fp32_param_list and fun_name in fp32_param_list) else None
setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
wrap_list = fp32_ops if fp32_ops is not None else list_fp32_ops(target_dtype)
for fun_name in get_aliases(wrap_list):
fun_name, modules = get_fun_to_wrap(fun_name, module)
for cur_module in modules:
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
else list_conditional_fp32_ops(target_dtype)
for fun_name, arg, arg_values in get_cond_aliases(wrap_list):
fun_name, modules = get_fun_to_wrap(fun_name, module)
for cur_module in modules:
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
for fun_name in get_aliases(list_widest_type_cast(target_dtype)):
fun_name, modules = get_fun_to_wrap(fun_name, module)
for cur_module in modules:
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
if not is_numpy_module and cur_module == module:
setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
def _wrap_loss_output_functions(module, ls, target_dtype):
if module == ndarray:
def _wrapper(f):
def _scaling_wrapper(*args, **kwargs):
if 'grad_scale' in kwargs:
kwargs['grad_scale'] = kwargs['grad_scale'] * ls.loss_scale
else:
kwargs['grad_scale'] = ls.loss_scale
return f(*args, **kwargs)
_scaling_wrapper.__name__ = f.__name__
_scaling_wrapper.__module__ = f.__module__
_scaling_wrapper.__doc__ = f.__doc__
return _scaling_wrapper
else:
def _wrapper(f):
def _warning_wrapper(*args, **kwargs):
logging.warning("%s does not support dynamic loss scaling "
"in symbolic and hybridized execution.", f.__name__)
return f(*args, **kwargs)
_warning_wrapper.__name__ = f.__name__
_warning_wrapper.__module__ = f.__module__
_warning_wrapper.__doc__ = f.__doc__
return _warning_wrapper
for fun_name in list_loss_output_functions(target_dtype):
try:
f_to_wrap = getattr(module, fun_name)
setattr(module, fun_name, _wrapper(f_to_wrap))
except AttributeError:
pass
_amp_initialized = False
_amp_loss_scale_initialized = False
_loss_scaler = None
@contextlib.contextmanager
def scale_loss(loss, optimizer_or_trainer):
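    """Scale the loss for dynamic loss scaling.
    Yields the loss (or list of losses) multiplied by the current loss scale, and
    adjusts the trainer's gradient rescaling factor so that the extra scale is
    removed again during the parameter update.
    Parameters
    ----------
    loss : NDArray or list of NDArray
        Loss (or losses) computed in the forward pass.
    optimizer_or_trainer : Trainer
        Gluon Trainer previously initialized with `amp.init_trainer()`.
    Examples
    --------
    A typical usage sketch inside a training loop; ``net``, ``loss_fn``, ``trainer``
    and the data pipeline are assumed to exist::

        from mxnet import amp, autograd
        with autograd.record():
            out = net(data)
            loss = loss_fn(out, label)
        with amp.scale_loss(loss, trainer) as scaled_loss:
            autograd.backward(scaled_loss)
        trainer.step(batch_size)
    """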
assert optimizer_or_trainer._amp_loss_scaler is not None, \
'Loss scaler is not initialized, did you forget to call amp.init_trainer()?'
optimizer_or_trainer._scale = (optimizer_or_trainer._amp_original_scale /
optimizer_or_trainer._amp_loss_scaler.loss_scale)
if isinstance(loss, (list, tuple)):
yield [l * optimizer_or_trainer._amp_loss_scaler.loss_scale for l in loss]
else:
yield optimizer_or_trainer._amp_loss_scaler.loss_scale * loss
def warn_if_model_exists():
for f in inspect.stack():
for k, v in f.frame.f_locals.items():
if isinstance(v, Block):
logging.warning('Block %s created in [%s:%d] before AMP init.',
k, f.filename, f.lineno)
return
def init(target_dtype='float16', target_precision_ops=None,
conditional_fp32_ops=None, fp32_ops=None, layout_optimization=False):
"""Initialize AMP (automatic mixed precision).
This needs to be done before model creation.
Parameters
----------
target_dtype : {'float16', 'bfloat16'}
Target low precision type for AMP. Currently only float16 and bfloat16 are supported.
target_precision_ops : list of string
Override the list of functions casted to target_dtype. Entries in this list
are names of the functions casted to target_dtype.
conditional_fp32_ops : list of (string, string, list of string)
Override the list of functions conditionally casted to FP32. The format
of the list is (name of the function, name of the parameter, list of
values of the parameter that make the function be casted to FP32).
    fp32_ops : list of string
        Override the list of functions casted to FP32. Entries in this list
        are names of the functions casted to FP32.
    layout_optimization : bool, default False
        Experimental. Whether to additionally enable layout optimization in the
        MXNet backend (via `MXSetOptimizeLayout`).
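    Examples
    --------
    A minimal usage sketch of the intended call order; ``build_model`` is a
    hypothetical placeholder for the user's own model construction code::

        from mxnet import amp
        amp.init(target_dtype='float16')
        net = build_model()    # create the network only after amp.init()
        net.initialize()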
"""
global _amp_initialized
global _loss_scaler
if not _amp_initialized:
assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
"AMP currently supports only float16 or bfloat16 as a target_dtype"
_amp_initialized = True
log_msg = "Using AMP"
if layout_optimization:
log_msg += "\n - layout optimization: enabled"
check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(True)))
logging.info(log_msg)
if target_dtype == "bfloat16":
target_dtype = bfloat16
else:
target_dtype = np.dtype(target_dtype)
warn_if_model_exists()
ops = get_all_registered_operators_grouped()
get_aliases_nd = lambda l: [a for op in l for a in ops[op] if not base._is_np_op(a)]
get_aliases_np = lambda l: [a for op in l for a in ops[op] if base._is_np_op(a)]
get_aliases_np_pub = lambda l: [a for op in l for a in ops[op]
if a.startswith(('_np_', '_npx_'))]
get_cond_aliases_nd = lambda l: [(a, *rest) for op, *rest in l for a in ops[op]
if not base._is_np_op(a)]
get_cond_aliases_np = lambda l: [(a, *rest) for op, *rest in l for a in ops[op]
if base._is_np_op(a)]
get_cond_aliases_np_pub = lambda l: [(a, *rest) for op, *rest in l for a in ops[op]
if a.startswith(('_np_', '_npx_'))]
sy_submodules = {p:getattr(symbol, p[1:-1]) for p in base._OP_NAME_PREFIX_LIST}
get_sy_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod, sy_submodules)
nd_submodules = {p:getattr(ndarray, p[1:-1]) for p in base._OP_NAME_PREFIX_LIST}
get_nd_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod, nd_submodules)
get_np_sy_fun = lambda fun, mod: _get_np_fun_to_wrap(fun, "mxnet.symbol")
get_np_nd_fun = lambda fun, mod: _get_np_fun_to_wrap(fun, "mxnet.ndarray")
get_np_fun = lambda fun, mode: _get_np_fun_to_wrap(fun, "mxnet")
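        # Namespaces to wrap: the legacy symbol/ndarray APIs and their numpy
        # counterparts, each paired with the alias/lookup helpers that match
        # its operator naming convention.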
todo = [
(symbol, False, get_aliases_nd, get_cond_aliases_nd, get_sy_fun),
(ndarray, False, get_aliases_nd, get_cond_aliases_nd, get_nd_fun),
(symbol.numpy, True, get_aliases_np, get_cond_aliases_np, get_np_sy_fun),
(ndarray.numpy, True, get_aliases_np, get_cond_aliases_np, get_np_nd_fun),
(numpy, True, get_aliases_np_pub, get_cond_aliases_np_pub, get_np_fun),
]
_loss_scaler = LossScaler()
for module, is_numpy, get_aliases, get_cond_aliases, get_fun in todo:
_wrap_module_functions(module, is_numpy, target_dtype, get_aliases, get_cond_aliases,
get_fun, target_precision_ops, conditional_fp32_ops, fp32_ops)
_wrap_loss_output_functions(module, _loss_scaler, target_dtype)
def init_trainer(optimizer_or_trainer):
"""Initialize trainer or optimizer to work with AMP dynamic loss scaling.
Parameters
----------
    optimizer_or_trainer : Optimizer or Trainer
        MXNet Optimizer or Gluon Trainer to initialize with AMP. Currently only
        the Gluon Trainer is supported; passing an Optimizer raises a TypeError.
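    Examples
    --------
    A minimal sketch; ``net`` is assumed to be an already created and initialized
    Gluon network::

        from mxnet import amp, gluon
        amp.init()
        trainer = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': 0.1})
        amp.init_trainer(trainer)  # attach the dynamic loss scaler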
"""
global _amp_loss_scale_initialized
global _amp_initialized
global _loss_scaler
assert _amp_initialized, "AMP not initialized, did you forget to call amp.init()?"
if not _amp_loss_scale_initialized:
_amp_loss_scale_initialized = True
loss_scaler = _loss_scaler
else:
loss_scaler = LossScaler()
#_wrap_output
if isinstance(optimizer_or_trainer, trainer.Trainer):
optimizer_or_trainer._amp_loss_scaler = loss_scaler
optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
trainer.Trainer.amp_loss_scale = property(lambda self: self._amp_loss_scaler.loss_scale)
elif isinstance(optimizer_or_trainer, opt.Optimizer):
raise TypeError("AMP is currently only compatible with Gluon Trainer")
else:
raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
f"an optimizer, instead is {type(optimizer_or_trainer)}")
def unscale(optimizer_or_trainer):
"""Check and unscale the gradients manually. This function should only be used
if accessing gradients is necessary, e.g. for gradient clipping.
Parameters
----------
    optimizer_or_trainer : Optimizer or Trainer
        MXNet Optimizer or Gluon Trainer used when scaling the gradients. Currently
        only the Gluon Trainer is supported.
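    Examples
    --------
    A sketch of inspecting gradients after manual unscaling; ``net`` and ``trainer``
    are assumed to be set up as in `init_trainer`, with the backward pass already
    done through `amp.scale_loss`::

        amp.unscale(trainer)   # gradients now hold their unscaled values
        grads = [p.grad() for p in net.collect_params().values()
                 if p.grad_req != 'null']
        # e.g. clip the gradients here before calling trainer.step()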
"""
if isinstance(optimizer_or_trainer, trainer.Trainer):
valid_grads = [p._grad for p in optimizer_or_trainer._params if p._grad is not None]
for grads in valid_grads:
# TODO(ptredak): make a bulked unscale
for g in grads:
g[:] *= optimizer_or_trainer._scale
optimizer_or_trainer._scale = 1.
elif isinstance(optimizer_or_trainer, opt.Optimizer):
# TODO(ptredak): make it work with the optimizer
raise TypeError("AMP is currently only compatible with Gluon Trainer")
else:
raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
f"an optimizer, instead is {type(optimizer_or_trainer)}")
def convert_symbol(sym, input_dtypes, param_dtypes, target_dtype, target_dtype_ops=None,
fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=[],
cast_params_offline=False):
"""Given a symbol object representing a neural network of data type FP32 and target_dtype,
add cast layers according to the op lists (target_dtype_ops, fp32_ops,
conditional_fp32_ops) if provided, otherwise use the default
lists provided by the framework.
Parameters
----------
sym : Symbol
FP32 neural network symbol
input_dtypes: dict
Dictionary mapping names of model inputs to their dtypes
param_dtypes: dict
Dictionary mapping names of model parameters to their dtypes
    target_dtype : str or numpy dtype, optional, defaults to 'float16'
        Currently only float16 and bfloat16 are supported. Cast layers to this dtype
        are added wherever possible, so that lower precision computation can be leveraged.
target_dtype_ops : list of strs, optional
Override the list of operator names casted to the target_dtype.
If None, uses the framework's default list to be casted to target_dtype.
fp32_ops : list of strs, optional
Override the list of operator names casted to FP32.
If None, uses the framework's default list to be casted to FP32.
conditional_fp32_ops : list of (string, string, list of string), optional
Override the list of functions to be casted to FP32.
The format of the list is
(name of the function, name of the parameter,
list of values of the parameter that make the operator to be casted to FP32)
excluded_sym_names : list of strs, optional
A list of strings that represent the names of symbols that users want to exclude
from being casted to LP16 or FP32.
cast_params_offline : bool, default False
Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
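    Examples
    --------
    A minimal sketch, assuming ``sym`` is an FP32 Symbol with a single input named
    ``'data'`` and ``arg_params``/``aux_params`` hold its FP32 parameters::

        import numpy as np
        from mxnet import amp
        input_dtypes = {'data': np.float32}
        param_dtypes = {name: arr.dtype
                        for name, arr in {**arg_params, **aux_params}.items()}
        amp_sym = amp.convert_symbol(sym, input_dtypes, param_dtypes,
                                     target_dtype='float16')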
"""
import json
assert isinstance(sym, Symbol), "First argument to convert_symbol should be a Symbol"
assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
"target_dtype_ops should be a list of strings"
assert fp32_ops is None or isinstance(fp32_ops, list), \
"fp32_ops should be a list of strings"
assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, list), \
"conditional_fp32_ops should be a list of strings"
target_dtype = get_dtype_name(target_dtype)
assert target_dtype in ['float16', *bfloat16.names], \
"Only float16 and bfloat16 types are currently supported as target_dtype"
if target_dtype_ops is None:
target_dtype_ops = list_lp16_ops(target_dtype)
if fp32_ops is None:
fp32_ops = list_fp32_ops(target_dtype)
# conditional ops
if conditional_fp32_ops is None:
conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
for cond_op in conditional_fp32_ops:
op_name, attr_name, attr_vals = cond_op
assert isinstance(op_name, str) and isinstance(attr_name, str) and isinstance(attr_vals, list), \
"conditional_fp32_ops should be a list of (str, str, list of str)"
cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
nodes_attrs = sym.attr_dict()
    nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
    # Work on a copy, so the caller's list (or the mutable default) is not modified
    # by the conditional-op exclusions appended below
    excluded_sym_names = list(excluded_sym_names)
for node_name, node_op in nodes_op.items():
if node_op not in cond_ops:
continue
node_attrs = nodes_attrs[node_name]
for attr_name, attr_vals in cond_ops[node_op].items():
assert attr_name in node_attrs
if node_attrs[attr_name] in attr_vals:
excluded_sym_names.append(node_name)
break
excluded_sym_names = set(excluded_sym_names)
for node in sym.get_internals():
if node.name in excluded_sym_names:
excluded_sym_names.remove(node.name)
opt_constraints = node.attr('__opt_constraint__')
opt_constraints = 0 if opt_constraints is None else int(opt_constraints)
opt_constraints |= HybridBlock.OptConstraint.Flag.DisableAMP.value
node._set_attr(__opt_constraint__=str(opt_constraints))
    if len(excluded_sym_names) > 0:
        logging.warning("The following excluded_sym_names are not present in the network. "
                        "Missing nodes: {}".format(excluded_sym_names))
# Op lists should not intersect
common_ops = set(target_dtype_ops) & set(fp32_ops)
assert len(common_ops) == 0, "Common ops in target_dtype_ops and fp32_ops: {}".format(common_ops)
common_ops = set(target_dtype_ops) & set(cond_ops)
assert len(common_ops) == 0, "Common ops in target_dtype_ops and conditional_fp32_ops: {}".format(
common_ops)
common_ops = set(cond_ops) & set(fp32_ops)
assert len(common_ops) == 0, "Common ops in fp32_ops and conditional_fp32_ops: {}".format(common_ops)
combined_ops = set(target_dtype_ops + fp32_ops + list(cond_ops.keys()))
original_cond_ops = [cond_op[0] for cond_op in list_conditional_fp32_ops(target_dtype)]
all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + list_fp32_ops(target_dtype) +
list_lp16_fp32_ops(target_dtype) + original_cond_ops)
illegal_ops = combined_ops - all_lp16_fp32_ops
assert len(illegal_ops) == 0, f'''Can only choose ops from one of the four lists
for lp16_ops and fp32_ops
1. amp.list_lp16_ops(target_dtype)
2. amp.list_fp32_ops(target_dtype)
3. amp.list_lp16_fp32_ops(target_dtype)
4. amp.list_conditional_fp32_ops(target_dtype)
Op {illegal_ops} not in any of them'''
widest_dtype_ops = list_widest_type_cast(target_dtype)
input_names = list(input_dtypes.keys())
all_arg_names, all_arg_types = [], []
for name, dtype in {**input_dtypes, **param_dtypes}.items():
all_arg_names.append(name)
all_arg_types.append(dtype_np_to_mx(dtype))
out = SymbolHandle()
check_call(_LIB.MXReducePrecisionSymbol(sym.handle,
ctypes.byref(out),
ctypes.c_int(dtype_np_to_mx(target_dtype)),
ctypes.c_int(cast_params_offline),
c_str(OFFLINE_CAST_DTYPE_ATTR),
ctypes.c_uint(len(input_names)),
c_str_array(input_names),
ctypes.c_uint(len(all_arg_names)),
c_str_array(all_arg_names),
c_array_buf(ctypes.c_int, array('i', all_arg_types)),
ctypes.c_uint(len(target_dtype_ops)),
c_str_array(target_dtype_ops),
ctypes.c_uint(len(fp32_ops)),
c_str_array(fp32_ops),
ctypes.c_uint(len(widest_dtype_ops)),
c_str_array(widest_dtype_ops)))
return type(sym)(out)
def convert_model(sym, arg_params, aux_params, input_dtypes, target_dtype,
target_dtype_ops=None, fp32_ops=None, conditional_fp32_ops=None,
excluded_sym_names=[], cast_params_offline=False):
"""API for converting a model from FP32 model to a mixed precision model.
MXNet tries to convert the FP32 model to mixed precision model by adding
cast layers using amp_cast and amp_multicast operators which can be used for inference use cases.
The decision on which cast layer to add is based on hardcoded lists for Automatic Mixed Precision
in MXNet. These lists can be overridden by the user by providing their own lists
using : targe_precision_ops, fp32_ops, widest_precision_ops, conditional_fp32_ops
arg_params : dict
Dictionary of name to `NDArray`.
aux_params : dict
Dictionary of name to `NDArray`.
input_dtypes: dict
Dictionary mapping names of model inputs to their dtypes
    target_dtype : str
        Currently only float16 and bfloat16 are supported. Cast layers to this dtype
        are added wherever possible, so that lower precision computation can be leveraged.
target_dtype_ops : list of strs
Override the list of operator names casted to target_dtype.
If None, uses the framework's default list to be casted to target dtype.
    fp32_ops : list of strs
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
conditional_fp32_ops : list of (string, string, list of string)
Override the list of operators to be casted to FP32.
The format of the list is
(name of the function, name of the parameter,
list of values of the parameter that make the operator to be casted to
fp32)
excluded_sym_names : list of strs
A list of strings that represent the names of symbols that users want to exclude
from being executed in lower precision.
cast_params_offline : bool, default False
Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
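    Examples
    --------
    A minimal sketch, assuming ``sym``, ``arg_params`` and ``aux_params`` come from a
    trained FP32 model (e.g. loaded from a checkpoint) with one input named ``'data'``::

        import numpy as np
        from mxnet import amp
        amp_sym, amp_args, amp_auxs = amp.convert_model(
            sym, arg_params, aux_params,
            input_dtypes={'data': np.float32},
            target_dtype='float16',
            cast_params_offline=True)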
"""
assert isinstance(sym, Symbol), "First argument to convert_model should be a Symbol"
assert isinstance(
arg_params, dict), "Second argument to convert_model should be a dict of name to ndarray"
assert isinstance(
aux_params, dict), "Third argument to convert_model should be a dict of name to ndarray"
arg_params = arg_params.copy()
aux_params = aux_params.copy()
param_dtypes = {name: data.dtype for name, data in arg_params.items()}
param_dtypes.update({name: data.dtype for name, data in aux_params.items()})
sym = convert_symbol(sym, input_dtypes, param_dtypes, target_dtype, target_dtype_ops,
fp32_ops, conditional_fp32_ops, excluded_sym_names, cast_params_offline)
# If dtype is set for params, cast the param to that dtype
attr_dict = sym.attr_dict()
for sym_name in sym.list_arguments():
if attr_dict.get(sym_name, {}).get(OFFLINE_CAST_DTYPE_ATTR, '') != '' and sym_name in arg_params:
typ = get_dtype_type(attr_dict[sym_name][OFFLINE_CAST_DTYPE_ATTR])
if arg_params[sym_name].dtype != typ:
arg_params[sym_name] = arg_params[sym_name].astype(typ)
for sym_name in sym.list_auxiliary_states():
if attr_dict.get(sym_name, {}).get(OFFLINE_CAST_DTYPE_ATTR, '') != '' and sym_name in aux_params:
typ = get_dtype_type(attr_dict[sym_name][OFFLINE_CAST_DTYPE_ATTR])
if aux_params[sym_name].dtype != typ:
aux_params[sym_name] = aux_params[sym_name].astype(typ)
# Return the converted symbol and casted params
return sym, arg_params, aux_params
@wrap_ctx_to_device_func
def convert_hybrid_block(block, data_example, target_dtype, target_dtype_ops=None,
fp32_ops=None, conditional_fp32_ops=None,
excluded_sym_names=[], device=None,
cast_params_offline=False):
"""Given a hybrid block/symbol block representing a FP32 model and a target_dtype,
return a block with mixed precision support which can be used for inference use cases.
Parameters
----------
block : HybridBlock or SymbolBlock object
FP32 HybridBlock or SymbolBlock object
data_example: tuple or list of NDArrays
Data example, representing the data that this model will work with during the inference.
    target_dtype : str or numpy dtype
        Currently only float16 and bfloat16 are supported. Cast layers to this dtype
        are added wherever possible, so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target_dtype.
    fp32_ops : list of strs
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
conditional_fp32_ops : list of (str, str, list of str)
Override the list of functions to be casted to FP32.
The format of the list is
(name of the function, name of the parameter,
        list of values of the parameter that make the operator to be casted to FP32)
excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being casted to LP16 or FP32
device : Device
Device on which model parameters should live. Default value: current device.
cast_params_offline : bool, default False
Whether to cast arg_params and aux_params now, instead of doing it every time at runtime.
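    Examples
    --------
    A minimal sketch, assuming ``net`` is an initialized (e.g. pretrained) HybridBlock
    and ``x`` is a representative input batch::

        from mxnet import amp
        amp_net = amp.convert_hybrid_block(net, data_example=[x],
                                           target_dtype='float16',
                                           cast_params_offline=True)
        out = amp_net(x)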
"""
from ..gluon import SymbolBlock
from ..ndarray import NDArray as ND_NDArray, waitall
from ..numpy import ndarray as NP_NDArray
assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
if not isinstance(data_example, (list, tuple)):
data_example = [data_example]
for data in data_example:
assert isinstance(data, (ND_NDArray, NP_NDArray)), "Data example must be composed of " \
"mxnet.numpy.ndarray or mxnet.ndarray.NDArray instances"
if not block._active:
block.hybridize(static_alloc=False, static_shape=False)
block(*data_example)
waitall()
sym, params = block.export(None, remove_amp_cast=False)
args, auxs = {}, {}
for name, data in params.items():
if name.startswith('arg:'):
arg_name = name[len('arg:'):]
args[arg_name] = data
else:
assert name.startswith('aux:')
aux_name = name[len('aux:'):]
auxs[aux_name] = data
input_names = set(sym.list_arguments()) - (set(args.keys()) | set(auxs.keys()))
input_names_ordered = HybridBlock.generate_arg_names(len(data_example))
assert input_names == set(input_names_ordered)
input_dtypes = {name: data.dtype for name, data in zip(input_names_ordered, data_example)}
lp_sym, lp_args, lp_auxs = convert_model(sym, args, auxs, input_dtypes, target_dtype,
target_dtype_ops, fp32_ops, conditional_fp32_ops,
excluded_sym_names, cast_params_offline)
inputs = [in_sym for in_sym in lp_sym.get_inputs() if in_sym.name in input_names]
param_dict = lp_args
param_dict.update(lp_auxs)
ret = SymbolBlock(lp_sym, inputs)
ret.load_dict(param_dict, device=device, cast_dtype=True, dtype_source='saved')
return ret
def list_lp16_ops(target_dtype):
"""Get the default list of LP16 ops for AMP
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.FP16_FUNCS
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.BF16_FUNCS
def list_fp32_ops(target_dtype):
"""Get the default list of FP32 ops for AMP
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.FP32_FUNCS
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.FP32_FUNCS
def list_lp16_fp32_ops(target_dtype):
"""Get the default list of ops which run in both LP16 and FP32
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.FP16_FP32_FUNCS
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.BF16_FP32_FUNCS
def list_conditional_fp32_ops(target_dtype):
"""Get the conditional fp32 ops list
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.CONDITIONAL_FP32_FUNCS
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.CONDITIONAL_FP32_FUNCS
def list_widest_type_cast(target_dtype):
"""Get the widest type cast ops list
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.WIDEST_TYPE_CASTS
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.WIDEST_TYPE_CASTS
def list_loss_output_functions(target_dtype):
"""Get loss function list
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.LOSS_OUTPUT_FUNCTIONS
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.LOSS_OUTPUT_FUNCTIONS
def list_lp16_use_fp32_params(target_dtype):
""" Get the params restrict for LP16
"""
if target_dtype in ['float16', np.float16]:
return None
else:
assert get_dtype_name(target_dtype) in bfloat16.names, "not supported type"
return lists.symbol_bf16.BF16_USE_FP32_PARAMS