blob: 0c8ce79a71804671ceee278be44402f801efaef2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from pathlib import Path
# Directory that contains this test file (symlinks resolved).
curr_path = Path(__file__).resolve().parent
# Prepend the parent directory and its 'unittest' subdirectory to the module
# search path. NOTE(review): presumably this is what lets the `amp.common`
# and `common` imports further down resolve -- confirm against the repo layout.
sys.path.insert(0, str(curr_path.parent))
sys.path.insert(0, str(curr_path.parent/'unittest'))
import mxnet as mx
import pytest
from mxnet import amp
from mxnet.test_utils import set_default_device
from mxnet.gluon import nn, rnn
import amp.common as amp_common_tests
from common import assert_raises_cudnn_not_satisfied
# Dtype string forwarded to the shared `amp.common` test helpers by the
# test_fp16_* wrappers in this module.
AMP_DTYPE = 'float16'
# Executed at import time, before any test runs.
# NOTE(review): presumably makes GPU 0 the default device for every test in
# this module -- confirm against mxnet.test_utils.set_default_device.
set_default_device(mx.gpu(0))
def test_fp16_coverage():
    """Run the shared AMP coverage test for this module's dtype.

    Delegates to ``amp.common.test_amp_coverage`` with ``AMP_DTYPE`` and the
    ``'FP16'`` tag.
    """
    amp_common_tests.test_amp_coverage(
        AMP_DTYPE,
        'FP16',
    )
@mx.util.use_np
def test_fp16_basic_use():
    """Run the shared AMP basic-use test with this module's ``AMP_DTYPE``."""
    shared_test = amp_common_tests.test_amp_basic_use
    shared_test(AMP_DTYPE)
@mx.util.use_np
def test_fp16_offline_casting():
    """Run the shared AMP offline-casting test with this module's ``AMP_DTYPE``."""
    shared_test = amp_common_tests.test_amp_offline_casting
    shared_test(AMP_DTYPE)
@mx.util.use_np
def test_fp16_offline_casting_shared_params():
    """Run the shared AMP offline-casting (shared params) test with ``AMP_DTYPE``."""
    shared_test = amp_common_tests.test_amp_offline_casting_shared_params
    shared_test(AMP_DTYPE)
@mx.util.use_np
def test_fp16_fp32_ops_order_independence():
    """Run the shared lp16/fp32 op-order-independence test with ``AMP_DTYPE``."""
    shared_test = amp_common_tests.test_lp16_fp32_ops_order_independence
    shared_test(AMP_DTYPE)
@mx.util.use_np
def test_fp16_test_node_excluding():
    """Run the shared AMP node-excluding test with this module's ``AMP_DTYPE``."""
    shared_test = amp_common_tests.test_amp_node_excluding
    shared_test(AMP_DTYPE)
@pytest.mark.skip(reason='Error during waitall(). Tracked in #18099')
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_amp_conversion_rnn(amp_tests):
    """Compare an LSTM + Dense model's output before and after AMP conversion.

    Builds a hybridized ``HybridSequential`` (two-layer bidirectional LSTM
    followed by ``Dense(2)``) on GPU 0, runs it on an all-ones input, converts
    it with ``amp.convert_hybrid_block`` and asserts that the converted model's
    output matches the original within ``atol = rtol = 1e-2``.

    Currently skipped unconditionally (see the skip reason on the decorator).

    NOTE(review): ``amp_tests`` is not referenced in the body; presumably it
    is a pytest fixture requested for its side effects -- confirm.
    """
    input_shape = (2, 3, 4)
    tolerance = 1e-2

    with mx.Device(mx.gpu(0)):
        net = nn.HybridSequential()
        net.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
        net.add(nn.Dense(2))
        net.initialize()
        net.hybridize()

        # Reference output must be produced before the conversion call,
        # matching the original statement order.
        reference_out = net(mx.nd.ones(input_shape))

        converted_net = amp.convert_hybrid_block(net)
        converted_out = converted_net(mx.nd.ones(input_shape))

        mx.test_utils.assert_almost_equal(
            reference_out.asnumpy(),
            converted_out.asnumpy(),
            atol=tolerance,
            rtol=tolerance,
        )