blob: 05030730806d396089628ca7d56ef1dca29b973f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys, os, logging, functools
import multiprocessing as mp
import mxnet as mx
import numpy as np
import random
import shutil
from mxnet.base import MXNetError
from mxnet.test_utils import environment
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../common/'))
sys.path.insert(0, os.path.join(curr_path, '../../../python'))
import models
from contextlib import contextmanager
import pytest
from mxnet.util import TemporaryDirectory
import locale
xfail_when_nonstandard_decimal_separator = pytest.mark.xfail(
locale.localeconv()["decimal_point"] != ".",
reason="Some operators break when the decimal separator is set to anything other than \".\". "
"These operators should be rewritten to utilize the new FFI. Please see #18097 for more "
"information."
)
def assertRaises(expected_exception, func, *args, **kwargs):
try:
func(*args, **kwargs)
except expected_exception as e:
pass
else:
# Did not raise exception
assert False, f"{func.__name__} did not raise {expected_exception.__name__}"
def default_logger():
"""A logger used to output seed information to logs."""
logger = logging.getLogger(__name__)
# getLogger() lookups will return the same logger, but only add the handler once.
if not len(logger.handlers):
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
logger.addHandler(handler)
if (logger.getEffectiveLevel() == logging.NOTSET):
logger.setLevel(logging.INFO)
return logger
@contextmanager
def random_seed(seed=None):
"""
Runs a code block with a new seed for np, mx and python's random.
Parameters
----------
seed : the seed to pass to np.random, mx.random and python's random.
To impose rng determinism, invoke e.g. as in:
with random_seed(1234):
...
To impose rng non-determinism, invoke as in:
with random_seed():
...
Upon conclusion of the block, the rng's are returned to
a state that is a function of their pre-block state, so
any prior non-determinism is preserved.
"""
try:
next_seed = np.random.randint(0, np.iinfo(np.int32).max)
if seed is None:
np.random.seed()
seed = np.random.randint(0, np.iinfo(np.int32).max)
logger = default_logger()
logger.debug('Setting np, mx and python random seeds = %s', seed)
np.random.seed(seed)
mx.random.seed(seed)
random.seed(seed)
yield
finally:
# Reinstate prior state of np.random and other generators
np.random.seed(next_seed)
mx.random.seed(next_seed)
random.seed(next_seed)
def _assert_raise_cuxx_version_not_satisfied(min_version, cfg):
def less_than(version_left, version_right):
"""Compares two version strings in the format num(.[num])*"""
if not version_left or not version_right:
return False
left = version_left.split(".")
right = version_right.split(".")
# 0 pad shortest version - e.g.
# less_than("9.1", "9.1.9") == less_than("9.1.0", "9.1.9")
longest = max(len(left), len(right))
left.extend([0] * (longest - len(left)))
right.extend([0] * (longest - len(right)))
# compare each of the version components
for l, r in zip(left, right):
if l == r:
continue
return int(l) < int(r)
return False
def test_helper(orig_test):
@functools.wraps(orig_test)
def test_new(*args, **kwargs):
cuxx_off = os.getenv(cfg['TEST_OFF_ENV_VAR']) == 'true'
cuxx_env_version = os.getenv(cfg['VERSION_ENV_VAR'], None if cuxx_off else cfg['DEFAULT_VERSION'])
cuxx_test_disabled = cuxx_off or less_than(cuxx_env_version, min_version)
if not cuxx_test_disabled or mx.device.current_device().device_type == 'cpu':
orig_test(*args, **kwargs)
else:
pytest.raises((MXNetError, RuntimeError), orig_test, *args, **kwargs)
return test_new
return test_helper
def assert_raises_cudnn_not_satisfied(min_version):
return _assert_raise_cuxx_version_not_satisfied(min_version, {
'TEST_OFF_ENV_VAR': 'CUDNN_OFF_TEST_ONLY',
'VERSION_ENV_VAR': 'CUDNN_VERSION',
'DEFAULT_VERSION': '7.3.1'
})
def assert_raises_cuda_not_satisfied(min_version):
return _assert_raise_cuxx_version_not_satisfied(min_version, {
'TEST_OFF_ENV_VAR': 'CUDA_OFF_TEST_ONLY',
'VERSION_ENV_VAR': 'CUDA_VERSION',
'DEFAULT_VERSION': '10.1'
})
def with_environment(*args_):
"""
Helper function that takes a dictionary of environment variables and their
desired settings and changes the environment in advance of running the
decorated code. The original environment state is reinstated afterwards,
even if exceptions are raised.
"""
def test_helper(orig_test):
@functools.wraps(orig_test)
def test_new(*args, **kwargs):
with environment(*args_):
orig_test(*args, **kwargs)
return test_new
return test_helper
def run_in_spawned_process(func, env, *args):
"""
Helper function to run a test in its own process.
Avoids issues with Singleton- or otherwise-cached environment variable lookups in the backend.
Adds a seed as first arg to propagate determinism.
Parameters
----------
func : function to run in a spawned process.
env : dict of additional environment values to set temporarily in the environment before exec.
args : args to pass to the function.
Returns
-------
Whether the python version supports running the function as a spawned process.
This routine calculates a random seed and passes it into the test as a first argument. If the
test uses random values, it should include an outer 'with random_seed(seed):'. If the
test needs to return values to the caller, consider use of shared variable arguments.
"""
try:
mpctx = mp.get_context('spawn')
except:
print(f'SKIP: python{sys.version_info[0]}.{sys.version_info[1]} lacks the required process fork-exec support ... ',
file=sys.stderr, end='')
return False
else:
seed = np.random.randint(0,1024*1024*1024)
with environment(env):
# Prepend seed as first arg
p = mpctx.Process(target=func, args=(seed,)+args)
p.start()
p.join()
assert p.exitcode == 0, f"Non-zero exit code {p.exitcode} from {func.__name__}()."
return True
def retry(n):
"""Retry n times before failing for stochastic test cases."""
# TODO(szha): replace with flaky
# https://github.com/apache/incubator-mxnet/issues/17803
assert n > 0
def test_helper(orig_test):
@functools.wraps(orig_test)
def test_new(*args, **kwargs):
"""Wrapper for tests function."""
for i in range(n):
try:
orig_test(*args, **kwargs)
return
except AssertionError as e:
if i == n-1:
raise e
mx.nd.waitall()
return test_new
return test_helper