# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import collections
import warnings

import mxnet as mx
from mxnet import amp
from mxnet.gluon import nn
from mxnet.operator import get_all_registered_operators_grouped

def test_amp_coverage(lp_dtype, lp_name):
    conditional = [item[0] for item in amp.list_conditional_fp32_ops(lp_dtype)]
    lp16_ops = amp.list_lp16_ops(lp_dtype)
    lp16_fp32_ops = amp.list_lp16_fp32_ops(lp_dtype)
    fp32_ops = amp.list_fp32_ops(lp_dtype)
    widest_ops = amp.list_widest_type_cast(lp_dtype)
    all_lp_lists = [lp16_ops, lp16_fp32_ops, fp32_ops, widest_ops, conditional]

    # Check for duplicates
    for op_list in all_lp_lists:
        ret = [op for op, count in collections.Counter(op_list).items() if count > 1]
        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."

    all_lp_ops = [op for op_list in all_lp_lists for op in op_list]
    ret = [op for op, count in collections.Counter(all_lp_ops).items() if count > 1]
    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."

    # Check the coverage
    covered_ops = set(all_lp_ops)
    all_mxnet_ops = get_all_registered_operators_grouped()
    required_ops = {op for op in all_mxnet_ops if "backward" not in op}

    extra_ops = covered_ops - required_ops
    assert not extra_ops, f"{len(extra_ops)} operators are not needed in the AMP lists: {sorted(extra_ops)}"

    guidelines = f"""Please follow these guidelines for choosing a proper list:
    - if your operator is not to be used in a computational graph
      (e.g. image manipulation operators, optimizers) or does not have
      inputs, put it in {lp_name.upper()}_FP32_FUNCS list,
    - if your operator requires FP32 inputs or is not safe to use with lower
      precision, put it in FP32_FUNCS list,
    - if your operator supports both FP32 and lower precision, has
      multiple inputs and expects all inputs to be of the same
      type, put it in WIDEST_TYPE_CASTS list,
    - if your operator supports both FP32 and lower precision and has
      either a single input or supports inputs of different type,
      put it in {lp_name.upper()}_FP32_FUNCS list,
    - if your operator is both safe to use in lower precision and
      it is highly beneficial to use it in lower precision, then
      put it in {lp_name.upper()}_FUNCS (this is unlikely for new operators)
    - If you are not sure which list to choose, FP32_FUNCS is the
      safest option"""

    missing_ops = required_ops - covered_ops
    if len(missing_ops) > 0:
        warnings.warn(f"{len(missing_ops)} operators {sorted(missing_ops)} do not exist in AMP lists "
                      f"(in python/mxnet/amp/lists/symbol_{lp_name.lower()}.py) - please add them.\n{guidelines}")

def test_amp_basic_use(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Dense(4)
            self.fc2 = nn.Dense(4)

        def forward(self, x):
            x = self.fc1(x)
            x = self.fc2(x)
            return x.reshape((-1, 2, 2))

    data_example = mx.np.random.uniform(-1, 1, (4, 4))

    net = TestNet()
    net.initialize()
    net = amp.convert_hybrid_block(net, data_example, lp_dtype)

    lp16_casts = 1   # cast for network input
    lp16_casts += 2  # casts for weights and bias of `fc1`
    lp16_casts += 2  # casts for weights and bias of `fc2`

    other_casts = 1  # cast for the network output (from lp16 to fp32)

    lp16_tensors = 1   # cast network input
    lp16_tensors += 3  # cast weights and bias of `fc1`, `fc1` output
    lp16_tensors += 3  # cast weights and bias of `fc2`, `fc2` output
    lp16_tensors += 1  # reshape output
    check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=lp16_tensors, lp16_casts_num=lp16_casts,
                        other_casts_num=other_casts)
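
# A standalone sketch of the conversion flow exercised above, assuming 'float16' as an
# example low-precision dtype. Without cast_params_offline (as in the basic test above),
# the parameters stay in FP32 and are cast at runtime, which is why that test counts one
# lp16 cast per weight/bias in addition to the input cast. The helper name is hypothetical.
def example_convert_for_inference():
    net = nn.Dense(4)
    net.initialize()
    data = mx.np.random.uniform(-1, 1, (4, 4))
    lp_net = amp.convert_hybrid_block(net, data, 'float16')  # eligible ops now run in float16
    return lp_net(data)  # the output is cast back to float32 at the network boundary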

def test_amp_offline_casting(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Conv2D(4, 3)
            self.lp16_op2 = nn.Conv2DTranspose(4, 3)
            self.fp32_op = nn.Dense(4)

        def forward(self, x):
            x = self.lp16_op1(x)
            x = self.lp16_op2(x)
            x = x.reshape(x.shape[0], -1)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x = self.fp32_op(x)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, 3, 16, 16))
    lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)

    check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4,
                        lp16_casts_num=1, other_casts_num=1)
    for name, param in lp_net.collect_params().items():
        assert mx.nd.get_dtype_name(param.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)
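
# Sketch of the effect checked above: with cast_params_offline=True the converted
# parameters are materialized in the low-precision dtype up front, so the only lp16
# amp_cast left in the graph is the one for the network input. 'float16' is an assumed
# example dtype; the helper name is hypothetical.
def example_offline_param_dtypes():
    net = nn.Dense(4)
    net.initialize()
    data = mx.np.random.uniform(-1, 1, (4, 4))
    lp_net = amp.convert_hybrid_block(net, data, 'float16', cast_params_offline=True)
    # Expected: 'float16' for both the weight and the bias of the Dense layer.
    return {name: mx.nd.get_dtype_name(param.dtype) for name, param in lp_net.collect_params().items()}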

def test_amp_offline_casting_shared_params(lp_dtype):
    COMMON_SIZE = 4

    class TestNet(nn.HybridBlock):
        def __init__(self):
            super().__init__()
            self.lp16_op1 = nn.Dense(COMMON_SIZE)
            self.lp16_op2 = nn.Dense(COMMON_SIZE)
            self.lp16_op2.share_parameters({'weight': self.lp16_op1.weight})
            self.fp32_op = nn.Dense(COMMON_SIZE)
            self.fp32_op.share_parameters({'bias': self.lp16_op2.bias})

        def forward(self, x):
            x = self.lp16_op1(x)
            x1 = self.lp16_op2(x)
            with nn.HybridBlock.OptConstraint.disable_amp():
                x2 = self.fp32_op(x)
            x = mx.np.concat((x1, x2), axis=1)
            return x

    net = TestNet()
    net.initialize()
    data_example = mx.np.random.uniform(-1, 1, (4, COMMON_SIZE))
    lp_net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)

    check_amp_net_stats(lp_dtype, lp_net, data_example, lp16_tensors_num=4,
                        lp16_casts_num=2, other_casts_num=2)
    for name, param in lp_net.collect_params().items():
        assert mx.nd.get_dtype_name(param.dtype) == ('float32' if 'fp32_op' in name else lp_dtype)
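
# Sketch of the parameter-sharing mechanism the test above builds on (plain Gluon, no AMP
# involved): share_parameters should make both layers reference the same Parameter object,
# so any dtype decision taken during offline casting is necessarily consistent between the
# sharing ops. The helper name is hypothetical.
def example_shared_weight_identity():
    dense_a = nn.Dense(4)
    dense_b = nn.Dense(4)
    dense_b.share_parameters({'weight': dense_a.weight})
    return dense_a.weight is dense_b.weight  # expected to be True: a single shared Parameter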

def test_lp16_fp32_ops_order_independence(lp_dtype):
    class TestNet(nn.HybridBlock):
        def __init__(self, lp16_fp32_is_first):
            super().__init__()
            if lp16_fp32_is_first:
                self.first = mx.npx.batch_flatten  # lp16_fp32_op
                self.second = nn.Dense(4)
            else:
                self.first = nn.Dense(4)
                self.second = mx.npx.batch_flatten  # lp16_fp32_op

        def forward(self, x):
            x = 2**x
            x1 = self.first(x)
            x2 = self.second(x)
            return x1, x2

    data_example = mx.np.random.uniform(-1, 1, (4, 16))
    for lp16_fp32_is_first in [False, True]:
        net = TestNet(lp16_fp32_is_first)
        net.initialize()
        net = amp.convert_hybrid_block(net, data_example, lp_dtype, cast_params_offline=True)
        check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num=3,
                            lp16_casts_num=1, other_casts_num=2)
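
# Quick standalone check of the property the test above depends on: batch_flatten is an
# lp16_fp32 op, i.e. it simply follows the dtype of its input, so its position relative
# to an lp16 op (Dense) should not change how many casts the converter has to insert.
# The helper name and the explicit dtypes are for illustration only.
def example_batch_flatten_follows_input_dtype():
    x16 = mx.np.ones((2, 3, 4), dtype='float16')
    x32 = mx.np.ones((2, 3, 4), dtype='float32')
    return mx.npx.batch_flatten(x16).dtype, mx.npx.batch_flatten(x32).dtype  # float16, float32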

def test_amp_node_excluding(lp_dtype):
    DISABLE_AMP_ATTR_DICT = {'__opt_constraint__': str(
        mx.gluon.HybridBlock.OptConstraint.Flag.DisableAMP.value)}

    data = mx.sym.var('data')
    wei = mx.sym.var('weights')
    bias = mx.sym.var('bias')
    # manually excluded
    fc1 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc1', attr=DISABLE_AMP_ATTR_DICT)
    # to be excluded using the conversion API
    fc2 = mx.sym.FullyConnected(data, wei, bias, num_hidden=4, name='fc2')
    symnet = mx.sym.Group([fc1, fc2])
    net = mx.gluon.SymbolBlock(symnet, [data])
    net.initialize()

    # exclude only the node with the attribute set (only 1 node - `fc1`); `fc2` is still converted
    data_example = mx.np.random.uniform(-1, 1, (4, 16))
    net_1_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype)
    lp16_tensors = 4  # cast `data`, casts of the weights and bias feeding `fc2`, `fc2` output
    lp16_casts = 3    # `data` cast, casts of the weights and bias feeding `fc2`
    other_casts = 1   # cast for the network output (from lp16 to fp32)
    check_amp_net_stats(lp_dtype, net_1_excluded, data_example, lp16_tensors_num=lp16_tensors,
                        lp16_casts_num=lp16_casts, other_casts_num=other_casts)

    # exclude using the `excluded_sym_names` argument (both nodes)
    net_2_excluded = amp.convert_hybrid_block(net, data_example, lp_dtype,
                                              excluded_sym_names=['fc1', 'fc2'])
    check_amp_net_stats(lp_dtype, net_2_excluded, data_example, lp16_tensors_num=0,
                        lp16_casts_num=0, other_casts_num=0)
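
# Sketch relating the attribute-based exclusion above to the Gluon API used earlier in this
# file: setting '__opt_constraint__' on `fc1` appears to correspond to what the
# nn.HybridBlock.OptConstraint.disable_amp() scope applies to ops created inside it, so a
# pure-Gluon network can exclude a node without touching symbol attributes. The class below
# is an illustrative example only and is not used by the tests.
class _ExampleExcludedDense(nn.HybridBlock):
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(4)

    def forward(self, x):
        # `fc` stays in FP32 after amp.convert_hybrid_block, like `fc1` above.
        with nn.HybridBlock.OptConstraint.disable_amp():
            return self.fc(x)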

def check_amp_net_stats(lp_dtype, net, data_example, lp16_tensors_num, lp16_casts_num, other_casts_num):
    lp16_tensors = set()
    lp16_casts = set()
    other_casts = set()

    def inspect_output(tensor_name, op_name, tensor):
        dtype = mx.nd.get_dtype_name(tensor.dtype)
        if op_name == 'amp_cast':
            if dtype == lp_dtype:
                lp16_casts.add(tensor_name)
            else:
                other_casts.add(tensor_name)
        if dtype == lp_dtype:
            lp16_tensors.add(tensor_name)

    net.register_op_hook(inspect_output)
    net(data_example)

    assert len(lp16_tensors) == lp16_tensors_num, f'Bad lp16 tensors! Present tensors: {sorted(lp16_tensors)}'
    assert len(lp16_casts) == lp16_casts_num, f'Bad lp16 casts! Present casts: {sorted(lp16_casts)}'
    assert len(other_casts) == other_casts_num, f'Bad casts! Present casts: {sorted(other_casts)}'
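
# Standalone sketch of the op-hook mechanism used by check_amp_net_stats: the registered
# callback is invoked for every output tensor of the executed graph together with the name
# of the producing op, which is how the casts and lp16 tensors are counted above.
# The helper name is hypothetical; `net` is assumed to be a converted/hybridized block.
def example_dump_tensor_dtypes(net, data_example):
    seen = []

    def dump(tensor_name, op_name, tensor):
        seen.append((tensor_name, op_name, mx.nd.get_dtype_name(tensor.dtype)))

    net.register_op_hook(dump)
    net(data_example)
    return seen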