# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=
"""Parallelization utility optimizer."""
__all__ = ['split_data', 'split_and_load', 'clip_global_norm',
'check_sha1', 'download', 'replace_file']
import os
import sys
import hashlib
import uuid
import warnings
import collections
import weakref
import requests
import numpy as np
from .. import ndarray
from ..util import is_np_shape, is_np_array, TemporaryDirectory
from .. import numpy as _mx_np # pylint: disable=reimported
def split_data(data, num_slice, batch_axis=0, even_split=True):
"""Splits an NDArray into `num_slice` slices along `batch_axis`.
Usually used for data parallelism where each slice is sent
to one device (e.g. a GPU).
Parameters
----------
data : NDArray
A batch of data.
num_slice : int
Number of desired slices.
batch_axis : int, default 0
The axis along which to slice.
even_split : bool, default True
Whether to force all slices to have the same number of elements.
If `True`, an error will be raised when `num_slice` does not evenly
divide `data.shape[batch_axis]`.
Returns
-------
list of NDArray
Return value is a list even if `num_slice` is 1.
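Examples
--------
A minimal sketch, assuming ``import mxnet as mx``:
>>> x = mx.nd.arange(12).reshape((6, 2))
>>> slices = mx.gluon.utils.split_data(x, 3)
>>> [s.shape for s in slices]
[(2, 2), (2, 2), (2, 2)]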
"""
size = data.shape[batch_axis]
if even_split and size % num_slice != 0:
raise ValueError(
f"data with shape {str(data.shape)} cannot be evenly split into {num_slice} slices " \
f"along axis {batch_axis}. Use a batch size that's multiple of {num_slice} " \
f"or set even_split=False to allow uneven partitioning of data.")
n_each_section, extras = divmod(size, num_slice)
section_sizes = [0] + (extras * [n_each_section + 1] +
(num_slice - extras) * [n_each_section])
div_points = np.array(section_sizes).cumsum()
if is_np_array():
slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
else:
slices = []
for i in range(num_slice):
st = div_points[i]
end = div_points[i + 1]
slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
return slices
def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
"""Splits an NDArray into `len(ctx_list)` slices along `batch_axis` and loads
each slice to one context in `ctx_list`.
Parameters
----------
data : NDArray or ndarray
A batch of data.
ctx_list : list of Context
A list of Contexts.
batch_axis : int, default 0
The axis along which to slice.
even_split : bool, default True
Whether to force all slices to have the same number of elements.
Returns
-------
list of NDArrays or ndarrays
Each corresponds to a context in `ctx_list`.
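Examples
--------
A minimal sketch, assuming ``import mxnet as mx`` and two CPU contexts:
>>> x = mx.nd.arange(8).reshape((4, 2))
>>> parts = mx.gluon.utils.split_and_load(x, [mx.cpu(0), mx.cpu(1)])
>>> [p.shape for p in parts]
[(2, 2), (2, 2)]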
"""
array_fn = _mx_np.array if is_np_array() else ndarray.array
if not isinstance(data, ndarray.NDArray):
data = array_fn(data, ctx=ctx_list[0])
if len(ctx_list) == 1:
return [data.as_in_context(ctx_list[0])]
slices = split_data(data, len(ctx_list), batch_axis, even_split)
return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]
def clip_global_norm(arrays, max_norm, check_isfinite=True):
"""Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
Parameters
----------
arrays : list of NDArray
max_norm : float
check_isfinite : bool, default True
If True, check that the total_norm is finite (not nan or inf). This
requires a blocking .item() call.
Returns
-------
NDArray or float
Total norm. Return type is NDArray of shape (1,) if check_isfinite is
False. Otherwise a float is returned.
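Examples
--------
A minimal sketch, assuming ``import mxnet as mx``; the arrays are rescaled
in place so that their global 2-norm does not exceed ``max_norm``:
>>> grads = [mx.np.ones((2, 2)), mx.np.ones((3,))]
>>> total = mx.gluon.utils.clip_global_norm(grads, max_norm=1.0)
>>> round(total, 3)
2.646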
"""
# group arrays by ctx
def group_by_ctx(arr_list):
groups = collections.defaultdict(list)
for arr in arr_list:
ctx = arr.device
groups[ctx].append(arr)
return groups
def multi_sum_sq(*args, ctx=None):
total = _mx_np.array([0], device=ctx)
for arg in args:
total += _mx_np.square(arg).sum().item()
return total
arrays_groups = group_by_ctx(arrays)
all_ctx_sum = _mx_np.array([0])
ctx = arrays[0].device
for group in arrays_groups:
sum_sq = multi_sum_sq(*arrays_groups[group], ctx=ctx)
all_ctx_sum += sum_sq
# global reduce
total_norm = _mx_np.sqrt(all_ctx_sum)
if check_isfinite:
if not np.isfinite(total_norm.item()):
warnings.warn(
UserWarning('nan or inf is detected. '
'Clipping results will be undefined.'), stacklevel=2)
scale = max_norm / (total_norm + 1e-8)
scale = _mx_np.min(_mx_np.concatenate([scale, _mx_np.ones(1, device=ctx)], axis=0))
for arr in arrays:
arr *= scale.item()
if check_isfinite:
return total_norm.item()
else:
return total_norm
def _indent(s_, numSpaces):
"""Indent string
"""
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [first] + [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
return s
def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash.
Parameters
----------
filename : str
Path to the file.
sha1_hash : str
Expected sha1 hash in hexadecimal digits.
Returns
-------
bool
Whether the file content matches the expected hash.
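Examples
--------
A minimal sketch, assuming ``import mxnet as mx`` and a hypothetical file and hash:
>>> expected = '2d7022c1cc2690be92b87b26c5a4c34b67c43a14'
>>> if not mx.gluon.utils.check_sha1('resnet18.params', expected):  # doctest: +SKIP
...     print('hash mismatch, file should be downloaded again')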
"""
sha1 = hashlib.sha1()
with open(filename, 'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
return sha1.hexdigest() == sha1_hash
if not sys.platform.startswith('win32'):
# refer to https://github.com/untitaker/python-atomicwrites
def replace_file(src, dst):
"""Implement atomic os.replace with linux and OSX.
Parameters
----------
src : source file path
dst : destination file path
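Examples
--------
A minimal sketch with hypothetical paths, assuming ``import mxnet as mx``:
>>> mx.gluon.utils.replace_file('model.params.tmp', 'model.params')  # doctest: +SKIP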
"""
try:
os.rename(src, dst)
except OSError:
try:
os.remove(src)
except OSError:
pass
finally:
raise OSError(
'Moving downloaded temp file {} to {} failed. '
'Please retry the download.'.format(src, dst))
else:
import ctypes
_MOVEFILE_REPLACE_EXISTING = 0x1
# Setting this value guarantees that a move performed as a copy
# and delete operation is flushed to disk before the function returns.
# The flush occurs at the end of the copy operation.
_MOVEFILE_WRITE_THROUGH = 0x8
_windows_default_flags = _MOVEFILE_WRITE_THROUGH
def _str_to_unicode(x):
"""Handle text decoding. Internal use only"""
if not isinstance(x, str):
return x.decode(sys.getfilesystemencoding())
return x
def _handle_errors(rv, src):
"""Handle WinError. Internal use only"""
if not rv:
msg = ctypes.FormatError(ctypes.GetLastError())
# if MoveFileExW fails (e.g. it cannot acquire the file lock), remove the temp file
try:
os.remove(src)
except OSError:
pass
finally:
raise OSError(msg)
def replace_file(src, dst):
"""Implement atomic os.replace with windows.
refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
The function fails when one of the process(copy, flush, delete) fails.
Parameters
----------
src : source file path
dst : destination file path
"""
_handle_errors(ctypes.windll.kernel32.MoveFileExW(
_str_to_unicode(src), _str_to_unicode(dst),
_windows_default_flags | _MOVEFILE_REPLACE_EXISTING
), src)
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download a given URL
Parameters
----------
url : str
URL to download
path : str, optional
Destination path to store the downloaded file. By default the file is stored
in the current directory with the same name as in the URL.
overwrite : bool, optional
Whether to overwrite the destination file if it already exists.
sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. An existing file is ignored (and
re-downloaded) when a hash is specified but does not match.
retries : integer, default 5
The number of times to attempt the download in case of failure or non-200 return codes.
verify_ssl : bool, default True
Verify SSL certificates.
Returns
-------
str
The file path of the downloaded file.
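Examples
--------
A minimal sketch with a hypothetical URL, assuming ``import mxnet as mx``;
the downloaded file is verified against ``sha1_hash`` when one is given:
>>> fname = mx.gluon.utils.download(
...     'https://example.com/data/train.rec', path='./data/')  # doctest: +SKIP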
"""
if path is None:
fname = url.split('/')[-1]
# Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. ' \
'Please set the `path` option manually.'
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
retries)
if not verify_ssl:
warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). '
'Adding certificate verification is strongly advised.')
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
while retries + 1 > 0:
# Disable pylint too-broad-Exception warning
# pylint: disable=W0703
try:
print('Downloading {} from {}...'.format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError('Failed downloading url {}'.format(url))
# create uuid for temporary files
random_uuid = str(uuid.uuid4())
with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
# if the target file already exists (created by another process)
# and matches the expected hash, delete the temporary file;
# otherwise atomically move the temporary file into place
if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
# atomic operation within the same file system
replace_file('{}.{}'.format(fname, random_uuid), fname)
else:
try:
os.remove('{}.{}'.format(fname, random_uuid))
except OSError:
pass
finally:
warnings.warn(
'File {} already exists in the file system, so the downloaded file was deleted'.format(fname))
if sha1_hash and not check_sha1(fname, sha1_hash):
raise UserWarning(
'File {} is downloaded but the content hash does not match.'
' The repo may be outdated or download may be incomplete. '
'If the "repo_url" is overridden, consider switching to '
'the default repo.'.format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
print('download failed due to {}, retrying, {} attempt{} left'
.format(repr(e), retries, 's' if retries > 1 else ''))
return fname
def _get_repo_url():
"""Return the base URL for Gluon dataset and model repository."""
default_repo = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
repo_url = os.environ.get('MXNET_GLUON_REPO', default_repo)
if repo_url[-1] != '/':
repo_url = repo_url+'/'
return repo_url
def _get_repo_file_url(namespace, filename):
"""Return the URL for hosted file in Gluon repository.
Parameters
----------
namespace : str
Namespace of the file.
filename : str
Name of the file
"""
return '{base_url}{namespace}/{filename}'.format(base_url=_get_repo_url(),
namespace=namespace,
filename=filename)
def _brief_print_list(lst, limit=7):
"""Print at most `limit` elements of list."""
lst = list(lst)
if len(lst) > limit:
return _brief_print_list(lst[:limit//2], limit) + ', ..., ' + \
_brief_print_list(lst[-limit//2:], limit)
return ', '.join([f"'{str(i)}'" for i in lst])
class HookHandle(object):
"""A handle that can attach/detach a hook."""
def __init__(self):
self._hooks_dict_ref = None
self._id = None
def attach(self, hooks_dict, hook):
assert not self._hooks_dict_ref, 'The same handle cannot be attached twice.'
self._id = id(hook)
hooks_dict[self._id] = hook
self._hooks_dict_ref = weakref.ref(hooks_dict)
def detach(self):
hooks_dict = self._hooks_dict_ref()
if hooks_dict is not None and self._id in hooks_dict:
del hooks_dict[self._id]
def __getstate__(self):
return (self._hooks_dict_ref(), self._id)
def __setstate__(self, state):
if state[0] is None:
self._hooks_dict_ref = weakref.ref(collections.OrderedDict())
else:
self._hooks_dict_ref = weakref.ref(state[0])
self._id = state[1]
def __enter__(self):
return self
def __exit__(self, ptype, value, trace):
self.detach()
def shape_is_known(shape):
"""Check whether a shape is completely known with or without np semantics.
Please see the doc of is_np_shape for more details.
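Examples
--------
A minimal sketch, assuming ``import mxnet as mx`` and classic (non-np) shape
semantics, where a dimension size of 0 marks an unknown dimension:
>>> mx.gluon.utils.shape_is_known((2, 3))
True
>>> mx.gluon.utils.shape_is_known((2, 0))
False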
"""
if shape is None:
return False
unknown_dim_size = -1 if is_np_shape() else 0
if len(shape) == 0:
return unknown_dim_size == -1
for dim_size in shape:
if dim_size == unknown_dim_size:
return False
assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \
"received {}".format(unknown_dim_size, dim_size)
return True
def _check_same_symbol_type(symbols):
"""Check whether all the symbols in the list are of the same type.
Raise type error if the types are different. Return the class of
the symbols."""
from ..symbol.numpy import _Symbol as np_symbol
from ..symbol import Symbol as nd_symbol
is_np_sym = isinstance(symbols[0], np_symbol)
for s in symbols[1:]:
if is_np_sym != isinstance(s, np_symbol):
raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
'(mx.sym.np._Symbol) in outputs. This will prevent you from building '
'a computation graph by grouping them since different types of symbols '
'are not allowed to be grouped in Gluon to form a computation graph. '
'You will need to convert them to the same type of symbols, either '
'classic or numpy following this rule: if you want numpy ndarray '
'output(s) from the computation graph, please convert all the classic '
'symbols in the list to numpy symbols by calling `as_np_ndarray()` '
'on each of them; if you want classic ndarray output(s) from the '
'computation graph, please convert all the numpy symbols in the list '
'to classic symbols by calling `as_nd_ndarray()` on each of them.')
return np_symbol if is_np_sym else nd_symbol
def _check_all_np_ndarrays(out):
"""Check if ndarrays/symbols in out are all np.ndarray/np._Symbol."""
from ..numpy import ndarray as np_ndarray
from ..symbol.numpy import _Symbol as np_symbol
from ..symbol import Symbol as nd_symbol
from ..ndarray import NDArray as nd_ndarray
# pylint: disable=no-else-raise
if isinstance(out, (nd_ndarray, nd_symbol)) and not isinstance(out, (np_ndarray, np_symbol)):
raise TypeError("Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray`"
" or `mxnet.symbol.numpy._Symbol`, while got output type {}"
.format(str(type(out))))
elif isinstance(out, (list, tuple)):
for i in out:
_check_all_np_ndarrays(i)
# pylint: enable=no-else-raise
def _check_block_input_np_ndarrays(inputs):
"""Check if block's inputs are numpy ndarrays."""
from ..numpy import ndarray as np_ndarray
from ..symbol import Symbol as nd_symbol
from ..ndarray import NDArray as nd_ndarray
# pylint: disable=no-else-raise
if isinstance(inputs, (nd_ndarray, nd_symbol)) and not isinstance(inputs, (np_ndarray)):
raise TypeError("Block's inputs must be of type `mxnet.numpy.ndarray`, "
"while got output type {}"
.format(str(type(inputs))))
elif isinstance(inputs, (list, tuple)):
for i in inputs:
_check_block_input_np_ndarrays(i)
# pylint: enable=no-else-raise
# pylint: disable=too-many-nested-blocks
def split_rnn_params(param, mode, num_layers, input_size, hidden_size, bidirectional=False, projection_size=None):
"""Split rnn layer parameter into weight and bias in different layer.
Parameters
----------
param : ndarray
The parameter of rnn layer.
mode : str
Mode of rnn. Supported modes: rnn_relu, rnn_tanh, lstm, gru
num_layers : int
Number of recurrent layers.
input_size: int
The number of expected features in the input x.
hidden_size: int
The number of features in the hidden state h.
bidirectional: bool, default False
If `True`, becomes a bidirectional RNN.
projection_size: int, default None
The number of features after projection.
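Returns
-------
dict of str to ndarray
Mapping from parameter names (e.g. 'l0_i2h_weight') to the corresponding slices of `param`.
Examples
--------
A minimal sketch for a single-layer, unidirectional LSTM with ``input_size=2``
and ``hidden_size=3`` (4 gates, hence 84 parameters in total), assuming
``import mxnet as mx``:
>>> flat = mx.np.arange(84)
>>> params = mx.gluon.utils.split_rnn_params(flat, 'lstm', 1, 2, 3)
>>> sorted(params.keys())
['l0_h2h_bias', 'l0_h2h_weight', 'l0_i2h_bias', 'l0_i2h_weight']
>>> params['l0_i2h_weight'].shape
(12, 2)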
"""
gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
dir = 2 if bidirectional else 1
param_dict = {}
begin = 0
if not projection_size:
for p in ['weight', 'bias']:
for l in range(num_layers):
for d in ['l', 'r'][:dir]:
for g in ['i2h', 'h2h']:
ni = input_size
if l != 0:
ni = hidden_size * dir
if g == 'h2h':
ni = hidden_size
shape0 = gates * hidden_size
if p == 'weight':
cur_len = shape0 * ni
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0, ni)
else:
cur_len = shape0
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0,)
begin += cur_len
else:
for p in ['weight', 'bias']:
for l in range(num_layers):
for d in ['l', 'r'][:dir]:
for g in ['i2h', 'h2h', 'h2r']:
if g != 'h2r' or p != 'bias':
if g == 'h2r':
cur_len = projection_size * hidden_size
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len]. \
reshape(projection_size, hidden_size)
else:
ni = input_size
if l != 0:
ni = projection_size * dir
if g == 'h2h':
ni = projection_size
shape0 = gates * hidden_size
if p == 'weight':
cur_len = shape0 * ni
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0, ni)
else:
cur_len = shape0
param_dict['{}{}_{}_{}'.format(d, l, g, p)] = \
param[begin:begin+cur_len].reshape(shape0,)
begin += cur_len
return param_dict