blob: df0e2f7159ab5b6a8d5c95460b44cde27c6b95ab [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from pathlib import Path
curr_path = Path(__file__).resolve().parent
sys.path.insert(0, str(curr_path.parent))
import pytest
import mxnet as mx
import amp.common as amp_common_tests
from mxnet.test_utils import assert_almost_equal
from mxnet.amp.lists.symbol_bf16 import (BF16_FUNCS, BF16_FP32_FUNCS, WIDEST_TYPE_CASTS,
CONDITIONAL_FP32_FUNCS)
from op_cfg import get_op_cfg_generator, get_symblock_from_args_scenario, CFG_RTOL_ATOL
# Every operator named by the bfloat16 AMP symbol lists: the three plain
# name lists, followed by the op names pulled out of the
# (op_name, attr_name, attr_vals) entries of CONDITIONAL_FP32_FUNCS.
ALL_BF16_OPS = [
    *BF16_FUNCS,
    *BF16_FP32_FUNCS,
    *WIDEST_TYPE_CASTS,
    *(op_name for op_name, _attr_name, _attr_vals in CONDITIONAL_FP32_FUNCS),
]
# Low-precision dtype passed to every shared AMP test helper in this module.
AMP_DTYPE = 'bfloat16'
def test_bf16_coverage():
    """Run the shared AMP operator-coverage check for bfloat16 (label 'BF16')."""
    coverage_check = amp_common_tests.test_amp_coverage
    coverage_check(AMP_DTYPE, 'BF16')
@mx.util.use_np
def test_bf16_basic_use():
    """Run the shared basic-usage AMP test with bfloat16 as the target dtype."""
    target_dtype = AMP_DTYPE
    amp_common_tests.test_amp_basic_use(target_dtype)
@mx.util.use_np
def test_bf16_offline_casting():
    """Run the shared offline-casting AMP test with bfloat16 as the target dtype."""
    offline_casting_check = amp_common_tests.test_amp_offline_casting
    offline_casting_check(AMP_DTYPE)
@mx.util.use_np
def test_bf16_offline_casting_shared_params():
    """Run the shared offline-casting-with-shared-parameters AMP test for bfloat16."""
    shared_params_check = amp_common_tests.test_amp_offline_casting_shared_params
    shared_params_check(AMP_DTYPE)
@mx.util.use_np
def test_bf16_fp32_ops_order_independence():
    """Run the shared low-precision/FP32 op-order-independence test for bfloat16."""
    order_check = amp_common_tests.test_lp16_fp32_ops_order_independence
    order_check(AMP_DTYPE)
@mx.util.use_np
def test_bf16_test_node_excluding():
    """Run the shared AMP node-excluding test with bfloat16 as the target dtype."""
    node_excluding_check = amp_common_tests.test_amp_node_excluding
    node_excluding_check(AMP_DTYPE)
def get_param_name(param):
    """Render one operator argument value as a short, readable label.

    MXNet tensors (``mx.nd.NDArray`` or ``mx.np.ndarray``) are shown only by
    their shape, as ``'Tensor'`` followed by ``str(shape)``. Lists and tuples
    are rendered element by element (recursively), keeping their container
    type. Any other value falls back to ``str``.
    """
    tensor_types = (mx.nd.NDArray, mx.np.ndarray)
    if isinstance(param, tensor_types):
        return f'Tensor{param.shape!s}'
    if not isinstance(param, (tuple, list)):
        return str(param)
    # Rebuild the same container type around the per-element labels.
    container_cls = type(param)
    return str(container_cls(map(get_param_name, param)))
def get_test_name(param):
    """Build the pytest id fragment for one parametrize value of ``test_bf16_op``.

    * A ``str`` is the operator name and becomes ``'"<name>" '``.
    * A ``dict`` is the argument configuration. Each entry becomes
      ``'"<arg>": <label>'``, and the entries are joined with ``', '`` behind a
      leading space. A key whose value is a tuple names several arguments at
      once as a comma-separated list. Those names are paired positionally
      with the tuple elements; ``zip`` stops at the shorter side.

    Raises:
        TypeError: if ``param`` is neither a ``str`` nor a ``dict``.
    """
    if isinstance(param, str):
        return f'"{param}" '  # op_name
    if not isinstance(param, dict):
        raise TypeError('Op configuration should only consist of its name (str) and arg config (dict)')

    labels = []
    for key, value in param.items():
        if isinstance(value, tuple):
            # Several arguments bound together under one comma-separated key.
            named_values = zip(key.split(','), value)
        else:
            named_values = [(key, value)]
        labels.extend(f'"{name}": {get_param_name(val)}' for name, val in named_values)
    return ' ' + ', '.join(labels)
def _as_list(outputs):
    """Return block outputs as a list, wrapping a single non-sequence output."""
    if isinstance(outputs, (list, tuple)):
        return list(outputs)
    return [outputs]


@pytest.mark.parametrize(argnames=('op_name', 'args_scenario'),
                         argvalues=get_op_cfg_generator(ALL_BF16_OPS, AMP_DTYPE),
                         ids=get_test_name)
def test_bf16_op(op_name, args_scenario):
    """Check one operator configuration in bfloat16 against a float32 reference.

    The symbol block built from ``args_scenario`` runs twice. The first run
    uses the generated inputs. The second uses a copy in which every input
    whose dtype name equals ``AMP_DTYPE`` has been cast to float32. The test
    fails in three cases: MXNet raises during either run; no input or output
    of the first run has dtype ``AMP_DTYPE``; or the two runs' outputs (cast
    to float32) differ beyond ``rtol``/``atol``.

    Args:
        op_name: name of the operator under test.
        args_scenario: argument configuration for the operator. It may carry
            a ``(rtol, atol)`` pair under the ``CFG_RTOL_ATOL`` key.
    """
    symblock, bf16_symblock_input_data = get_symblock_from_args_scenario(op_name, args_scenario)
    # Scenario-specific tolerances, else rtol=0.01 and atol=None (passed through
    # unchanged to assert_almost_equal).
    rtol, atol = args_scenario.get(CFG_RTOL_ATOL, (0.01, None))

    # Reference inputs: same tensors, with only the low-precision ones upcast.
    fp32_symblock_input_data = [
        tensor.astype('float32') if mx.nd.get_dtype_name(tensor.dtype) == AMP_DTYPE else tensor
        for tensor in bf16_symblock_input_data
    ]

    try:
        bf16_outs = symblock(*bf16_symblock_input_data)
        fp32_outs = symblock(*fp32_symblock_input_data)
        # Block until all pending MXNet work is done so execution errors are
        # raised inside this try and reported as a test failure.
        mx.nd.waitall()
    except mx.MXNetError as e:
        pytest.fail(str(e))

    bf16_outs = _as_list(bf16_outs)
    fp32_outs = _as_list(fp32_outs)

    # Guard against a vacuous pass: some input or output must really be bf16.
    # Unpack into a fresh list. `inputs + outputs` raises TypeError when one
    # side is a tuple and the other a list.
    bf16_run_tensors = [*bf16_symblock_input_data, *bf16_outs]
    assert any(mx.nd.get_dtype_name(tensor.dtype) == AMP_DTYPE for tensor in bf16_run_tensors), \
        f'no {AMP_DTYPE} tensor among the inputs or outputs of "{op_name}"'
    assert len(bf16_outs) == len(fp32_outs)
    for bf16_out, fp32_out in zip(bf16_outs, fp32_outs):
        assert_almost_equal(bf16_out.astype('float32'), fp32_out.astype('float32'), rtol, atol)