blob: c82e350b1b8b13c0b176a7430dcbe5c7db869022 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Register backend ops in mxnet.symbol namespace."""
import os as _os
import ctypes
import numpy as _np
from . import _internal
from .. import name as _name, attribute
from ._internal import SymbolBase, _symbol_creator
from ..base import mx_uint, check_call, _LIB, py_str
from ..symbol_doc import _build_doc
from ..base import _Null, _init_op_module, _is_np_op, _output_is_list
from ..name import NameManager
from ..profiler import _current_scope as _profiler_scope
from ..ndarray import get_dtype_name
# pylint: enable=unused-import
def _verify_np_symbol(op_name, func_name, sym):
    """Ensure that ``sym`` is an MXNet numpy symbol.

    Parameters
    ----------
    op_name : str
        Operator full name registered in backend.
    func_name : str
        Operator name exposed to users. This is usually the name by stripping off
        the prefix of the full operator names registered in backend.
    sym : symbol to be verified

    Raises
    ------
    TypeError
        If ``sym`` is not an instance of ``mxnet.symbol.numpy._symbol._Symbol``.
    """
    # Imported at call time rather than module load — presumably to avoid a
    # circular import with the numpy symbol package; TODO confirm.
    from .numpy._symbol import _Symbol as _NumpySymbol
    if isinstance(sym, _NumpySymbol):
        return
    message = ('Operator `{}` registered in backend is known as `{}` in Python. '
               'This is a numpy operator which can only accept '
               'MXNet numpy ndarrays, while received a legacy ndarray. '
               'Please ensure that you have activated numpy semantics by calling '
               '`npx.set_np()` in your code. If you still see this error with numpy '
               'semantics activated, please call `as_np_ndarray()` upon the legacy '
               'ndarray to convert it to an MXNet numpy ndarray, and then feed the '
               'converted array to this operator.'
               .format(op_name, func_name))
    raise TypeError(message)
def _verify_legacy_symbol(op_name, func_name, sym):
    """Ensure that ``sym`` is a legacy (non-numpy) symbol.

    Parameters
    ----------
    op_name : str
        Operator full name registered in backend.
    func_name : str
        Operator name exposed to users. This is usually the name by stripping off
        the prefix of the full operator names registered in backend.
    sym : symbol to be verified

    Raises
    ------
    TypeError
        If ``sym`` is an instance of ``mxnet.symbol.numpy._symbol._Symbol``.
    """
    # Imported at call time rather than module load — presumably to avoid a
    # circular import with the numpy symbol package; TODO confirm.
    from .numpy._symbol import _Symbol as _NumpySymbol
    if not isinstance(sym, _NumpySymbol):
        return
    message = ('Operator `{}` registered in backend is known as `{}` in Python. '
               'This is a legacy operator which can only accept '
               'legacy ndarrays, while received an MXNet numpy ndarray. '
               'Please call `as_nd_ndarray()` upon the numpy ndarray to '
               'convert it to a legacy ndarray, and then feed the converted '
               'array to this operator.'
               .format(op_name, func_name))
    raise TypeError(message)
def _query_symbol_op_info(handle, op_name):
    """Fetch an operator's argument metadata from the backend and build its docstring.

    Parameters
    ----------
    handle : ctypes handle
        Backend operator handle passed to ``MXSymbolGetAtomicSymbolInfo``.
    op_name : str
        Operator full name registered in backend (used for the docstring).

    Returns
    -------
    tuple
        ``(arg_names, arg_types, key_var_num_args, doc_str)``; ``arg_names`` and
        ``arg_types`` are parallel lists of ``str``.
    """
    real_name = ctypes.c_char_p()
    desc = ctypes.c_char_p()
    num_args = mx_uint()
    c_arg_names = ctypes.POINTER(ctypes.c_char_p)()
    c_arg_types = ctypes.POINTER(ctypes.c_char_p)()
    c_arg_descs = ctypes.POINTER(ctypes.c_char_p)()
    c_key_var_num_args = ctypes.c_char_p()
    c_ret_type = ctypes.c_char_p()
    check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
        handle, ctypes.byref(real_name), ctypes.byref(desc),
        ctypes.byref(num_args),
        ctypes.byref(c_arg_names),
        ctypes.byref(c_arg_types),
        ctypes.byref(c_arg_descs),
        ctypes.byref(c_key_var_num_args),
        ctypes.byref(c_ret_type)))
    narg = int(num_args.value)
    arg_names = [py_str(c_arg_names[i]) for i in range(narg)]
    arg_types = [py_str(c_arg_types[i]) for i in range(narg)]
    arg_descs = [py_str(c_arg_descs[i]) for i in range(narg)]
    key_var_num_args = py_str(c_key_var_num_args.value)
    # A NULL return-type pointer reads back as None; treat it as "no return type".
    ret_type = py_str(c_ret_type.value) if c_ret_type.value is not None else ''
    doc_str = _build_doc(op_name,
                         py_str(desc.value),
                         arg_names,
                         arg_types,
                         arg_descs,
                         key_var_num_args,
                         ret_type)
    return arg_names, arg_types, key_var_num_args, doc_str


def _classify_symbol_args(arg_names, arg_types):
    """Partition operator arguments by how the generated wrapper accepts them.

    Returns
    -------
    tuple
        ``(signature, ndarg_names, kwarg_names, arr_name, dtype_name)`` where

        * ``signature`` -- parameter strings for a fixed-arity ``def``
          (Symbol inputs first, then attributes, then ``name``/``attr``/``out``/``**kwargs``);
        * ``ndarg_names`` -- Symbol/NDArray inputs exposed as ``name=None`` parameters;
        * ``kwarg_names`` -- non-Symbol attributes exposed as ``name=_Null`` parameters
          (``dtype`` excluded);
        * ``arr_name`` -- name of the variadic ``...[]`` input, or ``None``;
        * ``dtype_name`` -- ``'dtype'`` if the op has such an argument, else ``None``.
    """
    dtype_name = None
    arr_name = None
    ndsignature = []
    signature = []
    ndarg_names = []
    kwarg_names = []
    for arg_name, arg_type in zip(arg_names, arg_types):
        if arg_name == 'dtype':
            # dtype gets its own code path so values can go through get_dtype_name.
            dtype_name = arg_name
            signature.append(f'{arg_name}=_Null')
        elif arg_type.startswith(('NDArray', 'Symbol')):
            assert not arr_name, \
                "Op can only have one argument with variable " \
                "size and it must be the last argument."
            if arg_type.endswith('[]'):
                ndsignature.append(f'*{arg_name}')
                arr_name = arg_name
            else:
                ndsignature.append(f'{arg_name}=None')
                ndarg_names.append(arg_name)
        else:
            signature.append(f'{arg_name}=_Null')
            kwarg_names.append(arg_name)
    signature += ['name=None', 'attr=None', 'out=None', '**kwargs']
    return ndsignature + signature, ndarg_names, kwarg_names, arr_name, dtype_name


def _variadic_symbol_body(handle, op_name, func_name, arr_name, dtype_name,
                          key_var_num_args, is_np_op, output_is_list, verify_symbol_fn):
    """Return source chunks for the body of an op whose inputs arrive as ``*arr_name``.

    The generated body validates each positional input, normalizes ``dtype``
    (when the op has one), splits ``**kwargs`` into Symbol inputs and plain
    attributes, defaults the variable-count attribute and calls ``_symbol_creator``.
    """
    code = ["""
    sym_args = []
    for i in {}:
        assert isinstance(i, SymbolBase), \\
            "Positional arguments must be Symbol instances, " \\
            "but got %s"%str(i)
        {}('{}', '{}', i)
        sym_args.append(i)""".format(arr_name, verify_symbol_fn, op_name, func_name)]
    if dtype_name is not None:
        code.append("""
    if '%s' in kwargs:
        kwargs['%s'] = get_dtype_name(kwargs['%s'])""" % (dtype_name, dtype_name, dtype_name))
    code.append("""
    attr = kwargs.pop('attr', None)
    kwargs.update(attribute.current().get(attr))
    name = kwargs.pop('name', None)
    name = _name.current().get(name, '%s')
    _ = kwargs.pop('out', None)
    keys = []
    vals = []
    sym_kwargs = dict()
    for k, v in kwargs.items():
        if isinstance(v, SymbolBase):
            sym_kwargs[k] = v
            %s('%s', '%s', v)
        else:
            keys.append(k)
            vals.append(v)""" % (func_name.lower(), verify_symbol_fn, op_name, func_name))
    if key_var_num_args:  # pylint: disable=using-constant-test
        # When the caller omits the variable-count attribute, default it to the
        # total number of Symbol inputs (positional + keyword).
        code.append("""
    if '%s' not in kwargs:
        keys.append('%s')
        vals.append(len(sym_args) + len(sym_kwargs))""" % (
            key_var_num_args, key_var_num_args))
    code.append("""
    if 'profiler_scope' not in keys:
        keys.append('profiler_scope')
        vals.append(_profiler_scope.get())
    return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s, %s)""" % (
        handle.value, str(is_np_op), str(output_is_list)))
    return code


def _fixed_arity_symbol_body(handle, op_name, func_name, ndarg_names, kwarg_names,
                             dtype_name, is_np_op, output_is_list, verify_symbol_fn):
    """Return source chunks for the body of an op with an explicit keyword signature.

    The generated body routes Symbol-valued ``**kwargs`` and supplied Symbol
    parameters into ``sym_kwargs``, forwards every attribute that is not
    ``_Null``, normalizes ``dtype`` and calls ``_symbol_creator``.
    """
    code = ["""
    kwargs.update(attribute.current().get(attr))
    sym_kwargs = dict()
    _keys = []
    _vals = []
    for _k, _v in kwargs.items():
        if isinstance(_v, SymbolBase):
            sym_kwargs[_k] = _v
            {}('{}', '{}', _v)
        else:
            _keys.append(_k)
            _vals.append(_v)""".format(verify_symbol_fn, op_name, func_name)]
    # Symbol inputs default to None; they are type-checked, recorded and verified
    # only when actually supplied (the verify call lives inside the None guard).
    for ndarg_name in ndarg_names:
        code.append("""
    if {name} is not None:
        assert isinstance({name}, SymbolBase), \\
            "Argument {name} must be Symbol instances, but got %s"%str({name})
        sym_kwargs['{name}'] = {name}
        {verify}('{op}', '{func}', {name})""".format(
            name=ndarg_name, verify=verify_symbol_fn, op=op_name, func=func_name))
    # Plain attributes default to _Null and are forwarded only when explicitly set.
    for kwarg_name in kwarg_names:
        code.append("""
    if %s is not _Null:
        _keys.append('%s')
        _vals.append(%s)""" % (kwarg_name, kwarg_name, kwarg_name))
    if dtype_name is not None:
        if is_np_op:
            # numpy ops additionally treat an explicit dtype=None as "not given".
            code.append("""
    if %s is not _Null and %s is not None:
        _keys.append('%s')
        _vals.append(get_dtype_name(%s))""" % (dtype_name, dtype_name, dtype_name, dtype_name))
        else:
            code.append("""
    if %s is not _Null:
        _keys.append('%s')
        _vals.append(get_dtype_name(%s))""" % (dtype_name, dtype_name, dtype_name))
    code.append("""
    name = _name.current().get(name, '%s')
    if 'profiler_scope' not in _keys:
        _keys.append('profiler_scope')
        _vals.append(_profiler_scope.get())
    return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s, %s)""" % (
        func_name.lower(), handle.value, str(is_np_op), str(output_is_list)))
    return code


def _indent_doc_str(doc_str):
    """Render ``doc_str`` as a raw-string docstring indented for a function body."""
    doc_lines = f'r"""{doc_str}"""'.splitlines(True)
    return _os.linesep + ''.join('    ' + line if line.strip() else line
                                 for line in doc_lines)


def _generate_symbol_function_code(handle, op_name, func_name, signature_only=False):
    """Generate the Python source of the wrapper function for a backend symbol op.

    Parameters
    ----------
    handle : ctypes handle
        Backend handle of the operator.
    op_name : str
        Operator full name registered in backend.
    func_name : str
        Name of the generated Python function.
    signature_only : bool, default False
        If True, the generated function only carries the signature and docstring,
        and its body is ``return (0,)``.

    Returns
    -------
    tuple of (str, str)
        The generated source code and the operator docstring.
    """
    arg_names, arg_types, key_var_num_args, doc_str = _query_symbol_op_info(handle, op_name)
    signature, ndarg_names, kwarg_names, arr_name, dtype_name = \
        _classify_symbol_args(arg_names, arg_types)
    is_np_op = _is_np_op(op_name)
    output_is_list = _output_is_list(op_name)
    # The generated code refers to the verifier by name; it is resolved in this
    # module's globals when the code is exec'd.
    verify_symbol_fn = (_verify_np_symbol if is_np_op else _verify_legacy_symbol).__name__

    if arr_name:
        code = ["""
def %s(*%s, **kwargs):""" % (func_name, arr_name)]
        if not signature_only:
            code += _variadic_symbol_body(handle, op_name, func_name, arr_name, dtype_name,
                                          key_var_num_args, is_np_op, output_is_list,
                                          verify_symbol_fn)
    else:
        code = ["""
def %s(%s):""" % (func_name, ', '.join(signature))]
        if not signature_only:
            code += _fixed_arity_symbol_body(handle, op_name, func_name, ndarg_names,
                                             kwarg_names, dtype_name, is_np_op,
                                             output_is_list, verify_symbol_fn)
    if signature_only:
        code.append("""
    return (0,)""")
    # The docstring must directly follow the ``def`` line (index 0).
    code.insert(1, _indent_doc_str(doc_str))
    return ''.join(code), doc_str
def _make_symbol_function(handle, name, func_name):
    """Compile and return the Python wrapper for one backend symbol operator.

    Parameters
    ----------
    handle : ctypes handle
        Backend handle of the operator.
    name : str
        Operator full name registered in backend.
    func_name : str
        Name under which the wrapper is exposed.

    Returns
    -------
    function
        The generated wrapper, with ``__name__``, ``__doc__`` and ``__module__`` set.
    """
    source, doc = _generate_symbol_function_code(handle, name, func_name)
    namespace = {}
    # globals=None: the generated function is defined with this module's globals,
    # so names it uses (SymbolBase, _symbol_creator, _Null, ...) resolve here,
    # while the new definition itself lands in ``namespace``.
    exec(source, None, namespace)  # pylint: disable=exec-used
    generated = namespace[func_name]
    generated.__name__ = func_name
    generated.__doc__ = doc
    generated.__module__ = 'mxnet.symbol'
    return generated
# Build the Python wrappers for the backend symbol operators, using
# _make_symbol_function as the per-op factory. (Presumably this also attaches
# them to the mxnet.symbol submodules such as contrib/_internal — see
# mxnet.base._init_op_module; confirm there.)
_init_op_module('mxnet', 'symbol', _make_symbol_function)
# Update operator documentation with added float support
# Note that we can only do this after the op module is initialized
# Otherwise the backend operators cannot be found
# pylint: disable=wrong-import-position
from .contrib import adamw_update, mp_adamw_update
from ._internal import _adamw_update, _mp_adamw_update
# Reuse the internal ops' generated docstrings for the public adamw wrappers,
# widening the documented type of `rescale_grad` to also allow a float.
adamw_update.__doc__ = _adamw_update.__doc__.replace("rescale_grad : Symbol",
                                                     "rescale_grad : Symbol or float")
mp_adamw_update.__doc__ = _mp_adamw_update.__doc__.replace("rescale_grad : Symbol",
                                                           "rescale_grad : Symbol or float")