blob: ac8601905fc4244b9de9d0faed800dc5f63b2143 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for registering the NumPy array function (``__array_function__``)
and array ufunc (``__array_ufunc__``) protocols for mxnet.numpy ops."""
import functools
import numpy as _np
from . import numpy as mx_np # pylint: disable=reimported
from .numpy.multiarray import _NUMPY_ARRAY_FUNCTION_DICT, _NUMPY_ARRAY_UFUNC_DICT
def _find_duplicate(strs):
    """Return the first element of ``strs`` that repeats an earlier element.

    Iteration stops at the first repeat. Returns ``None`` when every element
    is unique, including when ``strs`` is empty. Elements must be hashable.
    """
    seen = set()

    def _is_repeat(name):
        # Record each new name; report True only for a name seen before.
        if name in seen:
            return True
        seen.add(name)
        return False

    return next((name for name in strs if _is_repeat(name)), None)
def _implements(numpy_function):
    """Return a decorator that registers an MXNet override for ``numpy_function``.

    The decorated function is stored in ``_NUMPY_ARRAY_FUNCTION_DICT``, keyed by
    the official NumPy function object. The decorated function itself is
    returned unchanged.

    Parameters
    ----------
    numpy_function : callable
        The official NumPy function (for example ``numpy.sum``) being overridden.
        It is used as the dictionary key.

    Returns
    -------
    callable
        A decorator taking the ``mxnet.numpy`` implementation and returning it
        as-is after registration.
    """
    def decorator(func):
        # NOTE(review): the dispatch side that reads this dict lives in
        # ``mxnet.numpy.multiarray`` (not shown here).
        _NUMPY_ARRAY_FUNCTION_DICT[numpy_function] = func
        return func
    return decorator
def _numpy_version_at_least(min_version, version=None):
    """Return whether a NumPy version string is at least ``min_version``.

    Only the leading run of dot-separated integers is compared, so pre-release
    and development suffixes are ignored. For example, ``'1.17.0rc1'`` and
    ``'1.17.0.dev0+abc'`` both count as 1.17.0. This replaces the former use of
    ``distutils.version.LooseVersion``; ``distutils`` is deprecated (PEP 632)
    and no longer exists in Python 3.12+.

    Parameters
    ----------
    min_version : tuple of int
        Minimum required version, e.g. ``(1, 17)``.
    version : str, optional
        Version string to test. Defaults to the installed ``numpy.__version__``.

    Returns
    -------
    bool
        ``True`` if ``version >= min_version``. ``False`` otherwise, including
        when ``version`` does not start with a number.
    """
    import re
    if version is None:
        version = _np.__version__
    match = re.match(r'\d+(?:\.\d+)*', version)
    if match is None:
        # Unparseable version string: do not assume the protocol is available.
        return False
    parts = tuple(int(piece) for piece in match.group().split('.'))
    return parts >= tuple(min_version)


def _with_numpy_protocol(func, min_version, protocol_name):
    """Wrap ``func`` so that it only runs when NumPy >= ``min_version``.

    This is the shared implementation of ``with_array_function_protocol`` and
    ``with_array_ufunc_protocol``.

    Parameters
    ----------
    func : callable
        Function to wrap.
    min_version : tuple of int
        Oldest NumPy version for which ``func`` should run.
    protocol_name : str
        Protocol name inserted into the error message, e.g. ``'array function'``.

    Returns
    -------
    callable
        A wrapper that behaves as follows:

        * It returns ``func``'s result.
        * It returns ``None`` without calling ``func`` when the installed
          NumPy is older than ``min_version``.
        * Any exception raised by ``func`` is re-raised as ``RuntimeError``,
          chained to the original exception.
    """
    # Evaluated once, when the decorator is applied (as the original code did).
    supported = _numpy_version_at_least(min_version)

    @functools.wraps(func)
    def _run_with_protocol(*args, **kwargs):
        if not supported:
            return None
        try:
            return func(*args, **kwargs)
        except Exception as e:
            raise RuntimeError('Running function {} with NumPy {} protocol failed'
                               ' with exception {}'
                               .format(func.__name__, protocol_name, str(e))) from e
    return _run_with_protocol


def with_array_function_protocol(func):
    """A decorator for functions that expect array function protocol.

    The decorated function only runs when NumPy version >= 1.17. On older NumPy
    it returns ``None`` without calling ``func``. Otherwise it returns
    ``func``'s result. Any exception raised by ``func`` is re-raised as a
    ``RuntimeError`` chained to the original exception.
    """
    return _with_numpy_protocol(func, (1, 17), 'array function')


def with_array_ufunc_protocol(func):
    """A decorator for functions that expect array ufunc protocol.

    The decorated function only runs when NumPy version >= 1.15. On older NumPy
    it returns ``None`` without calling ``func``. Otherwise it returns
    ``func``'s result. Any exception raised by ``func`` is re-raised as a
    ``RuntimeError`` chained to the original exception.
    """
    return _with_numpy_protocol(func, (1, 15), 'array ufunc')
# Names of official NumPy functions that ``_register_array_function`` overrides
# with the ``mxnet.numpy`` operator of the same name. A dotted name such as
# ``'linalg.norm'`` refers to an attribute of a submodule. Every name must be
# unique, and the order below is the registration order.
_NUMPY_ARRAY_FUNCTION_LIST = [
    'all', 'any', 'sometrue', 'argmin', 'argmax', 'around', 'round', 'round_',
    'argsort', 'sort', 'append', 'broadcast_arrays', 'broadcast_to', 'clip',
    'concatenate', 'copy', 'cumsum', 'diag', 'diagonal', 'diagflat', 'dot',
    'expand_dims', 'fix', 'flip', 'flipud', 'fliplr', 'inner', 'insert', 'interp',
    'max', 'amax', 'mean', 'min', 'amin', 'nonzero', 'ones_like',
    'atleast_1d', 'atleast_2d', 'atleast_3d', 'prod', 'product', 'ravel',
    'repeat', 'reshape', 'roll', 'split', 'array_split', 'hsplit', 'vsplit',
    'dsplit', 'squeeze', 'stack', 'std', 'sum', 'swapaxes', 'take', 'tensordot',
    'tile', 'transpose', 'unique', 'unravel_index', 'flatnonzero',
    'diag_indices_from', 'delete', 'var', 'vdot', 'vstack', 'column_stack',
    'hstack', 'dstack', 'zeros_like',
    # Operators that live in the ``linalg`` submodule.
    'linalg.norm', 'linalg.cholesky', 'linalg.inv', 'linalg.solve',
    'linalg.tensorinv', 'linalg.tensorsolve', 'linalg.lstsq', 'linalg.pinv',
    'linalg.eigvals', 'linalg.eig', 'linalg.eigvalsh', 'linalg.eigh',
    'linalg.qr', 'linalg.matrix_rank',
    'shape', 'trace', 'tril', 'triu', 'meshgrid', 'outer', 'kron', 'einsum',
    'polyval', 'shares_memory', 'may_share_memory', 'quantile', 'median',
    'percentile', 'diff', 'ediff1d', 'resize', 'where', 'full_like', 'bincount',
    'empty_like', 'nan_to_num', 'isnan', 'isfinite', 'isposinf', 'isneginf',
    'isinf', 'pad', 'cross',
]
@with_array_function_protocol
def _register_array_function():
    """Register __array_function__ protocol for mxnet.numpy operators so that
    ``mxnet.numpy.ndarray`` can be fed into the official NumPy operators and
    dispatched to MXNet implementation.

    Each entry of ``_NUMPY_ARRAY_FUNCTION_LIST`` has one of two forms:

    * a top-level name, such as ``'sum'``;
    * a one-level submodule attribute, such as ``'linalg.norm'``.

    The ``mxnet.numpy`` operator with that name is registered through
    ``_implements`` against the official NumPy function with the same name.
    Some names are absent from the installed NumPy. Examples are the aliases
    ``round_``, ``sometrue`` and ``product``, which were removed in NumPy 2.0.
    Such names are skipped, because there is no official function to dispatch
    from.

    Raises
    ------
    ValueError
        If ``_NUMPY_ARRAY_FUNCTION_LIST`` contains a duplicate name, or a name
        nested more than one submodule deep.
    AttributeError
        If ``mxnet.numpy`` lacks a listed operator or submodule, or NumPy lacks
        a listed submodule.

    Both errors surface to callers as ``RuntimeError``, because
    ``with_array_function_protocol`` wraps any exception.

    Notes
    -----
    According to the __array_function__ protocol (see the following reference),
    there are three kinds of operators that cannot be dispatched using this
    protocol:

    1. Universal functions, which already have their own protocol in the official
       NumPy package.
    2. Array creation functions.
    3. Dispatch for methods of any kind, e.g., methods on np.random.RandomState objects.

    References
    ----------
    https://numpy.org/neps/nep-0018-array-function-protocol.html
    """
    dup = _find_duplicate(_NUMPY_ARRAY_FUNCTION_LIST)
    if dup is not None:
        raise ValueError('Duplicate operator name {} in _NUMPY_ARRAY_FUNCTION_LIST'.format(dup))
    for op_name in _NUMPY_ARRAY_FUNCTION_LIST:
        strs = op_name.split('.')
        if len(strs) == 1:
            mx_np_container, onp_container = mx_np, _np
        elif len(strs) == 2:
            mx_np_container = getattr(mx_np, strs[0])
            onp_container = getattr(_np, strs[0])
        else:
            raise ValueError('Does not support registering __array_function__ protocol '
                             'for operator {}'.format(op_name))
        attr_name = strs[-1]
        mx_np_op = getattr(mx_np_container, attr_name)
        onp_op = getattr(onp_container, attr_name, None)
        if onp_op is None:
            # The installed NumPy no longer provides this name (e.g. aliases
            # removed in NumPy 2.0), so there is nothing to override.
            continue
        setattr(mx_np_container, attr_name, _implements(onp_op)(mx_np_op))
# Ufuncs are listed at https://numpy.org/doc/stable/reference/ufuncs.html#available-ufuncs
# Names of ufuncs that ``_register_array_ufunc`` maps to the ``mxnet.numpy``
# operator of the same name. Every name must be unique.
_NUMPY_ARRAY_UFUNC_LIST = [
    'abs', 'fabs', 'add', 'arctan2', 'copysign', 'degrees', 'hypot', 'lcm', 'gcd',
    # NOTE(review): 'ldexp' is not registered; confirm whether that is intentional.
    'logaddexp', 'subtract', 'multiply', 'floor_divide', 'true_divide',
    'negative', 'power', 'mod', 'fmod', 'matmul', 'absolute', 'rint', 'sign',
    'exp', 'log', 'log2', 'log10', 'expm1', 'sqrt', 'square', 'cbrt',
    'reciprocal', 'invert', 'bitwise_not', 'remainder',
    'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh',
    'arcsin', 'arccos', 'arctan', 'arcsinh', 'arccosh', 'arctanh',
    'maximum', 'fmax', 'minimum', 'fmin', 'ceil', 'trunc', 'floor',
    'bitwise_and', 'bitwise_xor', 'bitwise_or',
    'logical_and', 'logical_or', 'logical_xor', 'logical_not',
    'equal', 'not_equal', 'less', 'less_equal', 'greater', 'greater_equal',
]
@with_array_ufunc_protocol
def _register_array_ufunc():
    """Register NumPy array ufunc protocol.

    This fills ``_NUMPY_ARRAY_UFUNC_DICT``: each ufunc name in
    ``_NUMPY_ARRAY_UFUNC_LIST`` is mapped to the ``mxnet.numpy`` operator of
    the same name.

    Raises
    ------
    ValueError
        If ``_NUMPY_ARRAY_UFUNC_LIST`` contains a duplicate name.
    AttributeError
        If ``mxnet.numpy`` has no operator with a listed name.

    Both errors surface to callers as ``RuntimeError``, because
    ``with_array_ufunc_protocol`` wraps any exception.

    References
    ----------
    https://numpy.org/neps/nep-0013-ufunc-overrides.html
    """
    dup = _find_duplicate(_NUMPY_ARRAY_UFUNC_LIST)
    if dup is not None:
        raise ValueError('Duplicate operator name {} in _NUMPY_ARRAY_UFUNC_LIST'.format(dup))
    for op_name in _NUMPY_ARRAY_UFUNC_LIST:
        # Only the lookup can fail; keep the try narrow and chain the cause.
        try:
            mx_np_op = getattr(mx_np, op_name)
        except AttributeError as err:
            raise AttributeError('mxnet.numpy does not have operator named {}'
                                 .format(op_name)) from err
        _NUMPY_ARRAY_UFUNC_DICT[op_name] = mx_np_op
# Perform both registrations at import time. Each call goes through its
# protocol decorator. When the installed NumPy is older than the decorator's
# minimum version, the call does nothing. Any failure during registration is
# raised as a RuntimeError.
_register_array_function()
_register_array_ufunc()