# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
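
"""Tests for MXNet contrib AMP (automatic mixed precision): operator-list coverage
and the convert_symbol / convert_model / convert_hybrid_block /
convert_bucketing_module APIs, run on GPU."""
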
import os
import sys
import mxnet as mx
import numpy as np
from random import randint
import warnings
import collections
import ctypes
from nose.tools import assert_raises
from mxnet.test_utils import set_default_context, download_model, same_symbol_structure
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon import SymbolBlock, nn, rnn
from mxnet.contrib.amp import amp
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed, teardown, assert_raises_cudnn_not_satisfied
sys.path.insert(0, os.path.join(curr_path, '../train'))
from test_bucketing import train_model
set_default_context(mx.gpu(0))


def test_amp_coverage():
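    """Every operator should appear in exactly one AMP list, and the lists together
    should cover all non-backward operators."""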
    conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]

    # Check for duplicates
    for a in [amp.lists.symbol.FP16_FUNCS,
              amp.lists.symbol.FP16_FP32_FUNCS,
              amp.lists.symbol.FP32_FUNCS,
              amp.lists.symbol.WIDEST_TYPE_CASTS,
              conditional]:
        ret = [item for item, count in collections.Counter(a).items() if count > 1]
        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."

    t = []
    for a in [amp.lists.symbol.FP16_FUNCS,
              amp.lists.symbol.FP16_FP32_FUNCS,
              amp.lists.symbol.FP32_FUNCS,
              amp.lists.symbol.WIDEST_TYPE_CASTS,
              conditional]:
        t += a
    ret = [item for item, count in collections.Counter(t).items() if count > 1]
    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."

    # Check the coverage
    py_str = lambda x: x.decode('utf-8')

    plist = ctypes.POINTER(ctypes.c_char_p)()
    size = ctypes.c_uint()
    mx.base._LIB.MXListAllOpNames(ctypes.byref(size),
                                  ctypes.byref(plist))
    op_names = []
    for i in range(size.value):
        s = py_str(plist[i])
        if not s.startswith("_backward") \
           and not s.startswith("_contrib_backward_"):
            op_names.append(s)

    ret1 = set(op_names) - set(t)
    if ret1 != set():
        warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in "
                      "python/mxnet/contrib/amp/lists/symbol.py) - please add them. "
                      """Please follow these guidelines for choosing a proper list:
                      - if your operator is not to be used in a computational graph
                        (e.g. image manipulation operators, optimizers) or does not have
                        inputs, put it in FP16_FP32_FUNCS list,
                      - if your operator requires FP32 inputs or is not safe to use with lower
                        precision, put it in FP32_FUNCS list,
                      - if your operator supports both FP32 and lower precision, has
                        multiple inputs and expects all inputs to be of the same
                        type, put it in WIDEST_TYPE_CASTS list,
                      - if your operator supports both FP32 and lower precision and has
                        either a single input or supports inputs of different type,
                        put it in FP16_FP32_FUNCS list,
                      - if your operator is both safe to use in lower precision and
                        it is highly beneficial to use it in lower precision, then
                        put it in FP16_FUNCS (this is unlikely for new operators)
                      - If you are not sure which list to choose, FP32_FUNCS is the
                        safest option""")


@with_seed()
def test_amp_conversion():
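    """Exercise the AMP conversion APIs (convert_symbol, convert_model,
    convert_hybrid_block, convert_bucketing_module) on small hand-built graphs
    and on a pretrained imagenet1k-resnet-18 checkpoint."""
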
    def check_amp_convert_symbol():
        x = mx.sym.var("x")
        y = mx.sym.var("y")
        z = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
        siny = mx.sym.sin(y)
        res = z + siny
        # Compare symbols with similar computation graphs created using convert_symbol and manually.
        res_converted = amp.convert_symbol(res, target_dtype="float16",
                                           target_dtype_ops=["FullyConnected"],
                                           fp32_ops=["sin"])

        x_fp16 = mx.sym.amp_cast(x, dtype="float16")
        y_fp16 = mx.sym.amp_cast(y, dtype="float16")
        amp_casted_siny = mx.sym.sin(mx.sym.amp_cast(y, dtype="float32"))
        z = mx.sym.FullyConnected(x_fp16, y_fp16, num_hidden=10, no_bias=True)
        outs = mx.sym.amp_multicast(z, amp_casted_siny, num_outputs=2)
        res_expected = outs[0] + outs[1]
        assert same_symbol_structure(res_converted, res_expected), \
            "convert_symbol generating wrong computation graph"

        # convert_symbol called with incorrect inputs
        assert_raises(AssertionError, amp.convert_symbol, res,
                      target_dtype="float16", target_dtype_ops=["FullyConnected"],
                      fp32_ops=["elemwise_add"])
        assert_raises(AssertionError, amp.convert_symbol, res,
                      target_dtype="float16", target_dtype_ops=["FullyConnected"],
                      fp32_ops=["Activation"],
                      conditional_fp32_ops=[('Activation', 'act_type', ['selu'])])
        assert_raises(AssertionError, amp.convert_symbol, res,
                      target_dtype="float16", target_dtype_ops=["Activation"],
                      fp32_ops=["Activation"],
                      conditional_fp32_ops=[('Activation', 'act_type', ['selu'])])
        assert_raises(AssertionError, amp.convert_symbol, res,
                      target_dtype="float16", target_dtype_ops=["FullyConnected"],
                      fp32_ops=["FullyConnected"])

        # Test for op in conditional ops with condition not satisfied
        x = mx.sym.var("x")
        y = mx.sym.var("y")
        fc_cond = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
        res_converted = amp.convert_symbol(fc_cond, target_dtype="float16",
                                           target_dtype_ops=[],
                                           fp32_ops=["sin"],
                                           conditional_fp32_ops=[("FullyConnected", "no_bias", ["False"])])
        res_expected = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True)
        assert same_symbol_structure(res_converted, res_expected), \
            "convert_symbol generating wrong computation graph when conditional ops are used"

        # Test for op in conditional ops with condition satisfied
        res_converted = amp.convert_symbol(fc_cond, target_dtype="float16", target_dtype_ops=[],
                                           fp32_ops=["sin"],
                                           conditional_fp32_ops=[("FullyConnected", "no_bias", ["True"])])
        x_fp32 = mx.sym.amp_cast(x, dtype="float32")
        y_fp32 = mx.sym.amp_cast(y, dtype="float32")
        res_expected = mx.sym.FullyConnected(x_fp32, y_fp32, num_hidden=10, no_bias=True)
        assert same_symbol_structure(res_converted, res_expected), \
            "convert_symbol generating wrong computation graph when conditional ops are used with the condition satisfied"

        # Test with a real world model, default inputs for convert_symbol
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_path = os.path.join(dir_path, 'model')
        if not os.path.isdir(model_path):
            os.mkdir(model_path)
        prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        inputs = {}
        inputs['data'] = mx.nd.ones((1, 3, 224, 224))
        inputs.update(arg_params)
        converted_sym = amp.convert_symbol(sym)
        exe = converted_sym.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null')
        exe.forward(is_train=False, **inputs)
        exe.outputs[0].asnumpy()

        inputs2 = {}
        inputs2['data'] = mx.nd.ones((1, 3, 224, 224))
        inputs2['fc1_weight'] = inputs['fc1_weight'].astype(np.float16)
        inputs2['fc1_bias'] = inputs['fc1_bias'].astype(np.float16)

        # Test with a real world model, tweak inputs for convert_symbol
        converted_sym = amp.convert_symbol(sym, target_dtype="float16",
                                           target_dtype_ops=["Convolution"], data_names=["data"],
                                           cast_optional_params=True)
        converted_sym2 = amp.convert_symbol(sym, target_dtype="float16",
                                            target_dtype_ops=["Convolution"], data_names=["data"],
                                            cast_optional_params=False)

        exe = converted_sym.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null')
        exe2 = converted_sym2.simple_bind(mx.gpu(), data=(1, 3, 224, 224), grad_req='null')

        converted_args = converted_sym.list_arguments()
        converted_auxs = converted_sym.list_auxiliary_states()
        for i, key in enumerate(exe.arg_arrays):
            if converted_args[i] in arg_params:
                arg_params[converted_args[i]] = arg_params[converted_args[i]].astype(exe.arg_arrays[i].dtype)
        for i, key in enumerate(exe.aux_arrays):
            if converted_auxs[i] in aux_params:
                aux_params[converted_auxs[i]] = aux_params[converted_auxs[i]].astype(exe.aux_arrays[i].dtype)

        inputs2.update(arg_params)
        exe.forward(is_train=False, **inputs2)
        exe.outputs[0].wait_to_read()

        inputs['fc1_weight'] = inputs['fc1_weight'].astype(np.float16)
        inputs['fc1_bias'] = inputs['fc1_bias'].astype(np.float16)
        exe2.forward(is_train=False, **inputs)
        exe2.outputs[0].wait_to_read()

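    # convert_model: the converted ResNet-18 should run through the Module API;
    # optional parameters stay FP32 unless cast_optional_params=True.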
    def check_amp_convert_model():
        # Test with real world model, default inputs for convert_model
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_path = os.path.join(dir_path, 'model')
        if not os.path.isdir(model_path):
            os.mkdir(model_path)
        prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

        # Test with real world model, tweak inputs for convert_model
        result_sym, result_arg_params, result_aux_params = amp.convert_model(sym,
                                                                             arg_params,
                                                                             aux_params,
                                                                             target_dtype="float16",
                                                                             target_dtype_ops=["Convolution"])
        mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.gpu())
        mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]])
        mod.set_params(result_arg_params, result_aux_params)
        mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                    label=[mx.nd.ones((1,))]))
        mod.get_outputs()[0].asnumpy()
        assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float32

        # Call convert_model with cast_optional_params set to True
        result_sym, result_arg_params, result_aux_params = amp.convert_model(sym,
                                                                             arg_params,
                                                                             aux_params,
                                                                             target_dtype="float16",
                                                                             target_dtype_ops=["Convolution"],
                                                                             cast_optional_params=True)
        mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.gpu())
        mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]])
        mod.set_params(result_arg_params, result_aux_params)
        mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                    label=[mx.nd.ones((1,))]))
        mod.get_outputs()[0].asnumpy()
        assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float16

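    # convert_hybrid_block: conversion of Gluon models (on CPU and GPU) and of an
    # imported SymbolBlock, with and without cast_optional_params.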
    def check_amp_convert_hybrid_block():
        # Test conversion for hybrid block on CPU
        model_cpu = get_model("resnet50_v1")
        model_cpu.collect_params().initialize(ctx=mx.cpu())
        model_cpu.hybridize()
        model_cpu(mx.nd.random.uniform(0, 1, shape=(1, 3, 224, 224), ctx=mx.cpu()))
        converted_model_cpu = amp.convert_hybrid_block(model_cpu)

        # Test with real world model, default inputs for convert_hybrid_block
        model = get_model("resnet50_v1")
        model.collect_params().initialize(ctx=mx.gpu())
        model.hybridize()
        model(mx.nd.zeros((1, 3, 224, 224)))
        converted_model = amp.convert_hybrid_block(model)
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224),
                                                     dtype=np.float32))
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224),
                                                     dtype=np.float32))

        # Test with real world model, tweak inputs for convert_hybrid_block
        converted_model = amp.convert_hybrid_block(model, target_dtype="float16",
                                                   target_dtype_ops=["Convolution"])
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224),
                                                     dtype=np.float32))
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224),
                                                     dtype=np.float32))

        # Check symbolic block
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_path = os.path.join(dir_path, 'model')
        if not os.path.isdir(model_path):
            os.mkdir(model_path)
        prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path)
        net = SymbolBlock.imports(os.path.join(model_path, "imagenet1k-resnet-18-symbol.json"),
                                  input_names=["data", "softmax_label"],
                                  param_file=os.path.join(model_path, "imagenet1k-resnet-18-0000.params"))
        net.collect_params().reset_ctx(ctx=mx.gpu())
        net.hybridize()
        net(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,)))
        converted_model = amp.convert_hybrid_block(net)
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,)))
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,)))

        # Check symbolic block, tweaked inputs
        converted_model = amp.convert_hybrid_block(net, target_dtype="float16", target_dtype_ops=["Convolution"])
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,)))
        result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,)))
        params = converted_model.collect_params()
        assert params["stage2_unit1_conv2_weight"].dtype == np.float32

        # Pass cast_optional_params as True to convert_hybrid_block
        converted_model = amp.convert_hybrid_block(net, target_dtype="float16", target_dtype_ops=["Convolution"],
                                                   cast_optional_params=True)
        params = converted_model.collect_params()
        assert params["stage2_unit1_conv2_weight"].dtype == np.float16

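    # convert_bucketing_module: convert a trained bucketing RNN and score it on
    # randomly generated sentences.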
    def check_amp_convert_bucketing_module():
        model = train_model(context=mx.current_context())
        result_model = amp.convert_bucketing_module(model)
        val_sent = []
        batch_size = 128
        invalid_label = -1
        num_sentence = 1000
        buckets = [5, 10, 20, 30, 40]
        len_vocab = 50

        for _ in range(num_sentence):
            len_sentence = randint(6, max(buckets)-1)  # leave out the two last buckets empty
            val_sentence = []
            for _ in range(len_sentence):
                val_sentence.append(randint(1, len_vocab))
            val_sent.append(val_sentence)

        data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
                                             invalid_label=invalid_label)
        result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
        result_model.score(data_val, mx.metric.Perplexity(invalid_label),
                           batch_end_callback=mx.callback.Speedometer(batch_size, 1))

        # AMP conversion with cast_optional_params set to true
        # Flaky test when cast_optional_params set to True : https://github.com/apache/incubator-mxnet/issues/16030
        '''
        result_model = amp.convert_bucketing_module(model, cast_optional_params=True)
        result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
        result_model.score(data_val, mx.metric.Perplexity(invalid_label),
                           batch_end_callback=mx.callback.Speedometer(batch_size, 1))
        '''

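    # Run all conversion checks on GPU 0.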
    with mx.Context(mx.gpu(0)):
        check_amp_convert_symbol()
        check_amp_convert_model()
        check_amp_convert_hybrid_block()
        check_amp_convert_bucketing_module()


@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_amp_conversion_rnn():
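    """convert_hybrid_block on a bidirectional LSTM stack should keep outputs close
    to the FP32 reference (within loose FP16 tolerances)."""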
    with mx.Context(mx.gpu(0)):
        model = nn.HybridSequential()
        model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
        model.add(nn.Dense(2))
        model.initialize()
        model.hybridize()
        out = model(mx.nd.ones((2, 3, 4)))

        new_model = amp.convert_hybrid_block(model)
        out2 = new_model(mx.nd.ones((2, 3, 4)))
        mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)


@with_seed()
def test_module_backward_compatibility():
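    """convert_model output should keep its inferred mixed-precision dtypes when
    reloaded through the Module API."""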
    channel_num = 10
    conv_layer_filter_dims = [2, 3]
    conv_layer_strides = [1, 1]
    dimension = 5
    data_len = 10

    data = mx.sym.var("data")
    conv = mx.sym.Convolution(data,
                              num_filter=channel_num,
                              kernel=tuple(conv_layer_filter_dims),
                              stride=tuple(conv_layer_strides))
    bn = mx.sym.BatchNorm(conv,
                          eps=0.001,
                          momentum=0.9,
                          fix_gamma=False,
                          use_global_stats=False,
                          output_mean_var=False,
                          name="conv0_batchnorm")
    fc = mx.sym.FullyConnected(bn, num_hidden=10, name="fullyconnected")

    mod = mx.mod.Module(fc, data_names=["data"], context=mx.gpu(0))
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]])
    mod.init_params()

    arg_params, aux_params = mod.get_params()
    for param_key, param_val in arg_params.items():
        assert param_val.dtype == np.float32, "Incorrect inference type for arg_params, " \
                                              "please check simple_bind for module executor"
    for param_key, param_val in aux_params.items():
        assert param_val.dtype == np.float32, "Incorrect inference type for aux_params, " \
                                              "please check simple_bind for module executor"

    sym, arg_params, aux_params = amp.convert_model(mod._symbol, mod._arg_params, mod._aux_params,
                                                    target_dtype_ops=["Convolution"])
    mod = mx.mod.Module(sym, data_names=["data"], context=mx.gpu(0))
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]])
    mod.set_params(arg_params, aux_params)
    assert arg_params["fullyconnected_weight"].dtype == np.float16, \
        "Module API is overwriting the inferred dtype for a mixed precision model"


@with_seed()
def test_fp16_casting():
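    """amp_cast / amp_multicast dtype inference for input parameters when
    cast_optional_params=True."""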
data = mx.sym.var("data")
out1 = mx.sym.amp_cast(data, dtype="float16")
out2 = mx.sym.amp_cast(data, dtype="float32")
out3 = mx.sym.amp_cast(data, dtype="float16")
# When two ops from data, with different dtypes,
# data should be float32
res = mx.sym.Group([out1, out2])
final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2))
assert exe.arg_arrays[0].dtype == np.float32
# When two ops from data, both casted to float16,
# data should be float16
res = mx.sym.Group([out1, out3])
final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2))
assert exe.arg_arrays[0].dtype == np.float16
# AMP Multicast test where one node is float32, another is float16
data = mx.sym.var("data", dtype=np.float32)
data2 = mx.sym.var("data2", dtype=np.float16)
out4 = mx.sym.amp_multicast(data, data2, num_outputs=2)
final_res = amp.convert_symbol(out4, cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2))
assert exe.arg_arrays[0].dtype == np.float16
# AMP Multicast test where two non input nodes are float16,
# and one input node is float32
data = mx.sym.var("data", dtype=np.float32)
data2 = mx.sym.var("data2", dtype=np.float16)
data3 = mx.sym.var("data3", dtype=np.float16)
out5 = mx.sym.amp_multicast(data,
mx.sym.elemwise_add(data2, data3),
num_outputs=2)
final_res = amp.convert_symbol(out5, target_dtype_ops=[],
fp32_ops=[], cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2))
assert exe.arg_arrays[0].dtype == np.float32
# AMP Multicast test where three input nodes one fp16, one fp32
# one unknown
data = mx.sym.var("data", dtype=np.float16)
data2 = mx.sym.var("data2", dtype=np.float32)
data3 = mx.sym.var("data3")
out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3)
final_res = amp.convert_symbol(out6, target_dtype_ops=[],
fp32_ops=[], cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2),
data3=(1, 2))
assert exe.arg_arrays[2].dtype == np.float32
# Input node to amp_multicast and amp_cast, if dtypes conflict
# and input node is already fp16, it should still be fp16
data = mx.sym.var("data", dtype=np.float16)
data2 = mx.sym.var("data2", dtype=np.float32)
out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
final_res = amp.convert_symbol(out7, target_dtype_ops=[],
fp32_ops=[], cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
assert exe.arg_arrays[0].dtype == np.float16
# Input node to amp_multicast and amp_cast, if dtypes conflict
# and input node is already fp32, it should be changed to fp16
data = mx.sym.var("data", dtype=np.float32)
data2 = mx.sym.var("data2", dtype=np.float16)
out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")])
final_res = amp.convert_symbol(out8, target_dtype_ops=[],
fp32_ops=[], cast_optional_params=True)
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
assert exe.arg_arrays[0].dtype == np.float16
if __name__ == '__main__':
    import nose
    nose.runmodule()