| # coding: utf-8 |
| # pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines |
| # pylint: disable=import-error, no-name-in-module |
| """Symbolic configuration API of MXNet.""" |
| from __future__ import absolute_import as _abs |
| |
| import ctypes |
| import warnings |
| from numbers import Number |
| |
| import os as _os |
| import sys as _sys |
| import numpy as _numpy |
| |
| from .base import _LIB, numeric_types |
| from .base import c_array, c_str, mx_uint, py_str, string_types |
| from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle |
| from .base import check_call, MXNetError, _Null # pylint: disable=unused-import |
| from .context import Context, cpu |
| from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP |
| from .name import NameManager # pylint: disable=unused-import |
| from .executor import Executor |
| from . import _symbol_internal as _internal |
| from .attribute import AttrScope |
| from .symbol_doc import _build_doc |
| |
| # Select which implementation of SymbolBase to use. |
| # When possible, use Cython to speed up part of the computation. |
| try: |
| if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: |
| from ._ctypes.symbol import SymbolBase, _set_symbol_class |
| from ._ctypes.symbol import _symbol_creator # pylint: disable=unused-import |
| elif _sys.version_info >= (3, 0): |
| from ._cy3.symbol import SymbolBase, _set_symbol_class |
| from ._cy3.symbol import _symbol_creator # pylint: disable=unused-import |
| else: |
| from ._cy2.symbol import SymbolBase, _set_symbol_class |
| from ._cy2.symbol import _symbol_creator # pylint: disable=unused-import |
| except ImportError: |
| if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: |
| raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") |
| from ._ctypes.symbol import SymbolBase, _set_symbol_class |
| from ._ctypes.symbol import _symbol_creator # pylint: disable=unused-import |
| |
| _GRAD_REQ_MAP = {'null': 0, 'write': 1, 'add': 3} |
| |
| class Symbol(SymbolBase): |
| """Symbol is symbolic graph of the mxnet.""" |
| # disable dictionary storage, also do not have parent type. |
| # pylint: disable=no-member |
| __slots__ = [] |
| |
def __repr__(self):
    """Builds a short, human-readable description of the symbol.

    Returns
    -------
    str
        ``'<ClassName name>'`` when the symbol has a name, otherwise
        ``'<ClassName group [n1, n2, ...]>'`` listing the name of every
        symbol obtained by iterating over this one.
    """
    label = self.name
    cls_name = self.__class__.__name__
    if label is not None:
        return '<%s %s>' % (cls_name, label)
    # No name of its own: describe the symbol through its members.
    members = ', '.join([member.name for member in self])
    return '<%s group [%s]>' % (cls_name, members)
| |
def __iter__(self):
    """Returns a generator over the outputs of the symbol.

    Outputs are fetched by integer position rather than by name, so a
    symbol whose outputs share a name can still be iterated (a lookup by
    name raises ``ValueError`` when the name is ambiguous), and each step
    avoids re-scanning the list of output names.

    Example usage:
    ----------
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.Variable('b')
    >>> c = a+b
    >>> d = mx.sym.Variable('d')
    >>> e = d+c
    >>> out = e.get_children()
    >>> for i in out:
    ...     i
    ...
    <Symbol d>
    <Symbol _plus0>

    Returns
    -------
    generator
        Yields ``self[0]``, ``self[1]``, ... for every output.
    """
    # The output count is read once, eagerly, when the generator is created.
    num_outputs = len(self.list_outputs())
    return (self[position] for position in range(num_outputs))
| |
def __add__(self, other):
    """x.__add__(y) <=> x+y

    Accepts another `Symbol` or a scalar (`Number`).
    Broadcasting is not supported. Use `broadcast_add` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a `Number`.
    """
    if isinstance(other, Symbol):
        return _internal._Plus(self, other)
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._PlusScalar(self, scalar=other)
| |
def __radd__(self, other):
    """x.__radd__(y) <=> y+x

    Forwards to `__add__` with the same argument.
    """
    total = self.__add__(other)
    return total
| |
def __sub__(self, other):
    """x.__sub__(y) <=> x-y

    Accepts another `Symbol` or a scalar (`Number`).
    Broadcasting is not supported. Use `broadcast_sub` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a `Number`.
    """
    if isinstance(other, Symbol):
        return _internal._Minus(self, other)
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._MinusScalar(self, scalar=other)
| |
def __rsub__(self, other):
    """x.__rsub__(y) <=> y-x

    Only a scalar (`Number`) left operand is accepted.

    Raises
    ------
    TypeError
        If `other` is not a `Number`.
    """
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._RMinusScalar(self, scalar=other)
| |
def __mul__(self, other):
    """x.__mul__(y) <=> x*y

    Accepts another `Symbol` or a scalar (`Number`).
    Broadcasting is not supported. Use `broadcast_mul` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a `Number`.
    """
    if isinstance(other, Symbol):
        return _internal._Mul(self, other)
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._MulScalar(self, scalar=other)
| |
def __rmul__(self, other):
    """x.__rmul__(y) <=> y*x

    Forwards to `__mul__` with the same argument.
    """
    product = self.__mul__(other)
    return product
| |
def __div__(self, other):
    """x.__div__(y) <=> x/y

    Accepts another `Symbol` or a scalar (`Number`).
    Broadcasting is not supported. Use `broadcast_div` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a `Number`.
    """
    if isinstance(other, Symbol):
        return _internal._Div(self, other)
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._DivScalar(self, scalar=other)
| |
def __rdiv__(self, other):
    """x.__rdiv__(y) <=> y/x

    Only a scalar (`Number`) left operand is accepted.

    Raises
    ------
    TypeError
        If `other` is not a `Number`.
    """
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._RDivScalar(self, scalar=other)
| |
def __mod__(self, other):
    """x.__mod__(y) <=> x%y

    Accepts another `Symbol` or a scalar (`Number`).
    Broadcasting is not supported. Use `broadcast_mod` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a `Number`.
    """
    if isinstance(other, Symbol):
        return _internal._Mod(self, other)
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._ModScalar(self, scalar=other)
| |
def __rmod__(self, other):
    """x.__rmod__(y) <=> y%x

    Only a scalar (`Number`) left operand is accepted.

    Raises
    ------
    TypeError
        If `other` is not a `Number`.
    """
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._RModScalar(self, scalar=other)
| |
def __truediv__(self, other):
    """x.__truediv__(y) <=> x/y

    Forwards to `__div__` with the same argument.
    """
    quotient = self.__div__(other)
    return quotient
| |
def __rtruediv__(self, other):
    """x.__rtruediv__(y) <=> y/x

    Forwards to `__rdiv__` with the same argument.
    """
    quotient = self.__rdiv__(other)
    return quotient
| |
def __pow__(self, other):
    """x.__pow__(y) <=> x**y

    Accepts another `Symbol` or a scalar (`Number`).
    Broadcasting is not supported. Use `broadcast_pow` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a `Number`.
    """
    if isinstance(other, Symbol):
        return _internal._Power(self, other)
    if not isinstance(other, Number):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._PowerScalar(self, scalar=other)
| |
def __neg__(self):
    """x.__neg__() <=> -x

    Numerical negative, implemented as multiplication by ``-1.0``
    through `__mul__`.

    Example usage:
    ----------
    >>> a = mx.sym.Variable('a')
    >>> a
    <Symbol a>
    >>> -a
    <Symbol _mulscalar0>
    """
    negated = self.__mul__(-1.0)
    return negated
| |
def __copy__(self):
    """Returns a copy of the symbol.

    Forwards to `__deepcopy__` with ``None`` as the memo argument.
    """
    duplicate = self.__deepcopy__(None)
    return duplicate
| |
def __deepcopy__(self, _):
    """Returns a new `Symbol` that is a copy of this one.

    The copy is produced by the backend call ``MXSymbolCopy`` and wrapped
    in a fresh `Symbol`. The memo argument is ignored.

    Example usage:
    ----------
    >>> import copy
    >>> data = mx.sym.Variable('data')
    >>> data_1 = copy.deepcopy(data)
    >>> data_1 is data
    False

    Returns
    -------
    Symbol
        A symbol wrapping the newly created handle.
    """
    copied_handle = SymbolHandle()
    out_ref = ctypes.byref(copied_handle)
    check_call(_LIB.MXSymbolCopy(self.handle, out_ref))
    return Symbol(copied_handle)
| |
def __eq__(self, other):
    """x.__eq__(y) <=> x==y

    Accepts another `Symbol` or a scalar (one of `numeric_types`).
    Broadcasting is not supported. Use `broadcast_equal` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a numeric scalar.
    """
    if isinstance(other, Symbol):
        return _internal._equal(self, other)
    if not isinstance(other, numeric_types):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._equal_scalar(self, scalar=other)
| |
def __ne__(self, other):
    """x.__ne__(y) <=> x!=y

    Accepts another `Symbol` or a scalar (one of `numeric_types`).
    Broadcasting is not supported. Use `broadcast_not_equal` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a numeric scalar.
    """
    if isinstance(other, Symbol):
        return _internal._not_equal(self, other)
    if not isinstance(other, numeric_types):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._not_equal_scalar(self, scalar=other)
| |
def __gt__(self, other):
    """x.__gt__(y) <=> x>y

    Accepts another `Symbol` or a scalar (one of `numeric_types`).
    Broadcasting is not supported. Use `broadcast_greater` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a numeric scalar.
    """
    if isinstance(other, Symbol):
        return _internal._greater(self, other)
    if not isinstance(other, numeric_types):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._greater_scalar(self, scalar=other)
| |
def __ge__(self, other):
    """x.__ge__(y) <=> x>=y

    Accepts another `Symbol` or a scalar (one of `numeric_types`).
    Broadcasting is not supported. Use `broadcast_greater_equal` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a numeric scalar.
    """
    if isinstance(other, Symbol):
        return _internal._greater_equal(self, other)
    if not isinstance(other, numeric_types):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._greater_equal_scalar(self, scalar=other)
| |
def __lt__(self, other):
    """x.__lt__(y) <=> x<y

    Accepts another `Symbol` or a scalar (one of `numeric_types`).
    Broadcasting is not supported. Use `broadcast_lesser` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a numeric scalar.
    """
    if isinstance(other, Symbol):
        return _internal._lesser(self, other)
    if not isinstance(other, numeric_types):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._lesser_scalar(self, scalar=other)
| |
def __le__(self, other):
    """x.__le__(y) <=> x<=y

    Accepts another `Symbol` or a scalar (one of `numeric_types`).
    Broadcasting is not supported. Use `broadcast_lesser_equal` instead.

    Raises
    ------
    TypeError
        If `other` is neither a `Symbol` nor a numeric scalar.
    """
    if isinstance(other, Symbol):
        return _internal._lesser_equal(self, other)
    if not isinstance(other, numeric_types):
        raise TypeError('type %s not supported' % str(type(other)))
    return _internal._lesser_equal_scalar(self, scalar=other)
| |
def __getstate__(self):
    """Returns the picklable state of the symbol.

    Returns
    -------
    dict
        ``{'handle': <result of self.tojson()>}`` when the symbol has a
        handle, ``{'handle': None}`` otherwise.
    """
    if self.handle is None:
        return {'handle': None}
    # The handle itself is not stored; the JSON form stands in for it.
    return {'handle': self.tojson()}
| |
def __setstate__(self, state):
    """Restores the symbol from a state produced by `__getstate__`.

    Parameters
    ----------
    state : dict
        Must contain the key ``'handle'``, holding either ``None`` or a
        JSON string from which a new handle is created.
    """
    # pylint: disable=assigning-non-slot
    json_str = state['handle']
    if json_str is None:
        self.handle = None
        return
    restored = SymbolHandle()
    check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(restored)))
    self.handle = restored
| |
def __call__(self, *args, **kwargs):
    """Composes symbol using inputs.

    x.__call__(y, z) <=> x(y,z)

    A copy of this symbol is made with `__copy__`, `_compose` is applied
    to the copy, and the copy is returned; this symbol is not the object
    passed to `_compose`.

    Example usage:
    ----------
    >>> data = mx.symbol.Variable('data')
    >>> net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    >>> net2 = mx.symbol.FullyConnected(name='fc3', num_hidden=10)
    >>> composed = net2(fc3_data=net1, name='composed')
    >>> composed
    <Symbol composed>

    Parameters
    ----------
    args:
        Positional arguments, forwarded to `_compose`.

    kwargs:
        Keyword arguments, forwarded to `_compose`.

    Returns
    -------
    The composed copy.
    """
    composed = self.__copy__()
    composed._compose(*args, **kwargs)
    return composed
| |
def _compose(self, *args, **kwargs):
    """Composes symbol using inputs.

    x._compose(y, z) <=> x(y,z)

    This function mutates the current symbol.

    Example usage:
    ----------
    >>> data = mx.symbol.Variable('data')
    >>> net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    >>> net2 = mx.symbol.FullyConnected(name='fc3', num_hidden=10)
    >>> net2
    <Symbol fc3>
    >>> net2._compose(fc3_data=net1, name='composed')
    >>> net2
    <Symbol composed>

    Parameters
    ----------
    args:
        Input symbols given positionally.

    kwargs:
        Input symbols given by keyword. The special keyword ``name`` is
        removed from `kwargs` and passed on as the name argument of the
        backend call; an empty or missing name is passed as ``None``.

    Raises
    ------
    TypeError
        If inputs are given both positionally and by keyword, or if any
        input is not a `Symbol`.
    """
    name = kwargs.pop('name', None)
    if name:
        name = c_str(name)

    if args and kwargs:
        # Adjacent literals keep source indentation out of the message
        # (a backslash continuation inside the literal would embed it).
        raise TypeError('compose only accept input Symbols '
                        'either as positional or keyword arguments, not both')

    for candidate in list(args) + list(kwargs.values()):
        if not isinstance(candidate, Symbol):
            raise TypeError('Compose expect `Symbol` as arguments')

    num_args = len(args) + len(kwargs)
    if kwargs:
        keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs])
        handles = [sym.handle for sym in kwargs.values()]
    else:
        # No keys: the backend receives ``None`` for the key array.
        keys = None
        handles = [sym.handle for sym in args]
    check_call(_LIB.MXSymbolCompose(
        self.handle, name, num_args, keys, c_array(SymbolHandle, handles)))
| |
def __getitem__(self, index):
    """x.__getitem__(i) <=> x[i]

    Returns the symbol for one output of the input symbol.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> a.__getitem__(0)
    <Symbol a>
    >>> a[0]
    <Symbol a>

    Parameters
    ----------
    index : int or str
        Indexing key. A string selects the output with that name. A
        negative integer counts from the last output, as for built-in
        sequences.

    Raises
    ------
    ValueError
        If `index` is a string that matches no output or more than one.
    TypeError
        If `index` is neither a string nor an integer.
    IndexError
        If the integer index is out of range.
    """
    output_names = None
    if isinstance(index, string_types):
        output_names = self.list_outputs()
        matches = [i for i, name in enumerate(output_names) if name == index]
        if len(matches) > 1:
            raise ValueError('There are multiple outputs with name "%s"' % index)
        if not matches:
            raise ValueError('Cannot find output that matches name "%s"' % index)
        index = matches[0]
    if not isinstance(index, int):
        raise TypeError('Symbol only support integer index to fetch i-th output')
    if output_names is None:
        output_names = self.list_outputs()
    num_outputs = len(output_names)
    if index < 0:
        # Without this, a negative value would be handed to mx_uint unchecked.
        index += num_outputs
    if index < 0 or index >= num_outputs:
        # Important, python determines the end by this exception
        raise IndexError('Symbol output index out of range')
    handle = SymbolHandle()
    check_call(_LIB.MXSymbolGetOutput(
        self.handle, mx_uint(index), ctypes.byref(handle)))
    return Symbol(handle=handle)
| |
@property
def name(self):
    """Gets name string from the symbol, this function only works for non-grouped symbol.

    Returns
    -------
    value : str
        The name of this symbol, returns ``None`` for grouped symbol.
    """
    name_out = ctypes.c_char_p()
    found = ctypes.c_int()
    check_call(_LIB.MXSymbolGetName(
        self.handle, ctypes.byref(name_out), ctypes.byref(found)))
    # A zero flag from the backend means no name was produced.
    if found.value == 0:
        return None
    return py_str(name_out.value)
| |
def attr(self, key):
    """Returns the attribute string for corresponding input key from the symbol.

    This function only works for non-grouped symbols.

    Example usage:
    ----------
    >>> data = mx.sym.Variable('data', attr={'mood': 'angry'})
    >>> data.attr('mood')
    'angry'

    Parameters
    ----------
    key : str
        The key corresponding to the desired attribute.

    Returns
    -------
    value : str
        The desired attribute value, returns ``None`` if the attribute does not exist.
    """
    value_out = ctypes.c_char_p()
    found = ctypes.c_int()
    check_call(_LIB.MXSymbolGetAttr(
        self.handle, c_str(key), ctypes.byref(value_out), ctypes.byref(found)))
    # A zero flag from the backend means the attribute was not found.
    if found.value == 0:
        return None
    return py_str(value_out.value)
| |
def list_attr(self, recursive=False):
    """Gets all attributes from the symbol.

    Example usage:
    ----------
    >>> data = mx.sym.Variable('data', attr={'mood': 'angry'})
    >>> data.list_attr()
    {'mood': 'angry'}

    Parameters
    ----------
    recursive : bool
        No longer supported; a true value raises `DeprecationWarning`.

    Returns
    -------
    ret : Dict of str to str
        A dictionary mapping attribute keys to values.

    Raises
    ------
    DeprecationWarning
        If `recursive` is true.
    """
    if recursive:
        raise DeprecationWarning("Symbol.list_attr with recursive=True has been deprecated. "
                                 "Please use attr_dict instead.")
    count = mx_uint()
    flat_pairs = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXSymbolListAttrShallow(
        self.handle, ctypes.byref(count), ctypes.byref(flat_pairs)))
    # The backend returns a flat array: key0, value0, key1, value1, ...
    attrs = {}
    for idx in range(count.value):
        attrs[py_str(flat_pairs[2 * idx])] = py_str(flat_pairs[2 * idx + 1])
    return attrs
| |
def attr_dict(self):
    """Recursively gets all attributes from the symbol and its children.

    Example usage:
    ----------
    >>> a = mx.sym.Variable('a', attr={'a1':'a2'})
    >>> b = mx.sym.Variable('b', attr={'b1':'b2'})
    >>> c = a+b
    >>> c.attr_dict()
    {'a': {'a1': 'a2'}, 'b': {'b1': 'b2'}}

    Returns
    -------
    ret : Dict of str to dict
        There is a key in the returned dict for every child with non-empty attribute set.
        For each symbol, the name of the symbol is its key in the dict
        and the correspond value is that symbol's attribute list (itself a dictionary).
    """
    count = mx_uint()
    flat_pairs = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXSymbolListAttr(
        self.handle, ctypes.byref(count), ctypes.byref(flat_pairs)))
    grouped = {}
    for idx in range(count.value):
        # Each key has the form '<symbol name>$<attribute key>'.
        sym_name, attr_key = py_str(flat_pairs[idx * 2]).split('$')
        attr_val = py_str(flat_pairs[idx * 2 + 1])
        grouped.setdefault(sym_name, {})[attr_key] = attr_val
    return grouped
| |
def _set_attr(self, **kwargs):
    """Sets an attribute of the symbol.

    For example. A._set_attr(foo="bar") adds the mapping ``"{foo: bar}"``
    to the symbol's attribute dictionary.

    Attributes are applied one at a time in iteration order, so those
    preceding an invalid value have already been set when the error is
    raised.

    Parameters
    ----------
    **kwargs
        The attributes to set

    Raises
    ------
    ValueError
        If a value is not a string.
    """
    for attr_name, attr_value in kwargs.items():
        if isinstance(attr_value, string_types):
            check_call(_LIB.MXSymbolSetAttr(
                self.handle, c_str(attr_name), c_str(str(attr_value))))
        else:
            raise ValueError("Set Attr only accepts string values")
| |
def get_internals(self):
    """Gets a new grouped symbol `sgroup`. The output of `sgroup` is a list of
    outputs of all of the internal nodes.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> b = mx.sym.var('b')
    >>> c = a + b
    >>> d = c.get_internals()
    >>> d
    <Symbol Grouped>
    >>> d.list_outputs()
    ['a', 'b', '_plus4_output']

    Returns
    -------
    sgroup : Symbol
        A symbol group containing all internal and leaf nodes of the computation graph
        used to compute the symbol.
    """
    internals_handle = SymbolHandle()
    out_ref = ctypes.byref(internals_handle)
    check_call(_LIB.MXSymbolGetInternals(self.handle, out_ref))
    return Symbol(handle=internals_handle)
| |
def get_children(self):
    """Gets a new grouped symbol whose output contains
    inputs to output nodes of the original symbol.

    Example usage:
    ----------
    >>> x = mx.sym.Variable('x')
    >>> y = mx.sym.Variable('y')
    >>> z = mx.sym.Variable('z')
    >>> a = y+z
    >>> b = x+a
    >>> b.get_children()
    <Symbol Grouped>
    >>> b.get_children().list_outputs()
    ['x', '_plus10_output']
    >>> b.get_children().get_children().list_outputs()
    ['y', 'z']

    Returns
    -------
    sgroup : Symbol or None
        The children of the head node. If the symbol has no
        inputs then ``None`` will be returned.
    """
    children_handle = SymbolHandle()
    check_call(_LIB.MXSymbolGetChildren(
        self.handle, ctypes.byref(children_handle)))
    children = Symbol(handle=children_handle)
    # A group with no outputs is reported as "no children".
    return children if children.list_outputs() else None
| |
def list_arguments(self):
    """Lists all the arguments in the symbol.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> b = mx.sym.var('b')
    >>> c = a + b
    >>> c.list_arguments()
    ['a', 'b']

    Returns
    -------
    args : list of string
        List containing the names of all the arguments required to compute the symbol.
    """
    count = ctypes.c_uint()
    raw_names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXSymbolListArguments(
        self.handle, ctypes.byref(count), ctypes.byref(raw_names)))
    names = []
    for idx in range(count.value):
        names.append(py_str(raw_names[idx]))
    return names
| |
def list_outputs(self):
    """Lists all the outputs in the symbol.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> b = mx.sym.var('b')
    >>> c = a + b
    >>> c.list_outputs()
    ['_plus12_output']

    Returns
    -------
    list of str
        List of all the outputs.
        For most symbols, this list contains only the name of this symbol.
        For symbol groups, this is a list with the names of all symbols
        in the group.
    """
    count = ctypes.c_uint()
    raw_names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXSymbolListOutputs(
        self.handle, ctypes.byref(count), ctypes.byref(raw_names)))
    names = []
    for idx in range(count.value):
        names.append(py_str(raw_names[idx]))
    return names
| |
def list_auxiliary_states(self):
    """Lists all the auxiliary states in the symbol.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> b = mx.sym.var('b')
    >>> c = a + b
    >>> c.list_auxiliary_states()
    []

    Example of auxiliary states in `BatchNorm`.

    >>> data = mx.symbol.Variable('data')
    >>> weight = mx.sym.Variable(name='fc1_weight')
    >>> fc1 = mx.symbol.FullyConnected(data = data, weight=weight, name='fc1', num_hidden=128)
    >>> fc2 = mx.symbol.BatchNorm(fc1, name='batchnorm0')
    >>> fc2.list_auxiliary_states()
    ['batchnorm0_moving_mean', 'batchnorm0_moving_var']

    Returns
    -------
    aux_states : list of str
        List of the auxiliary states in input symbol.

    Notes
    -----
    Auxiliary states are special states of symbols that do not correspond to an argument,
    and are not updated by gradient descent. Common examples of auxiliary states
    include the `moving_mean` and `moving_variance` in `BatchNorm`.
    Most operators do not have auxiliary states.
    """
    count = ctypes.c_uint()
    raw_names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXSymbolListAuxiliaryStates(
        self.handle, ctypes.byref(count), ctypes.byref(raw_names)))
    names = []
    for idx in range(count.value):
        names.append(py_str(raw_names[idx]))
    return names
| |
def list_inputs(self):
    """Lists all arguments and auxiliary states of this Symbol.

    Returns
    -------
    inputs : list of str
        List of all inputs.

    Examples
    --------
    >>> bn = mx.sym.BatchNorm(name='bn')
    >>> bn.list_arguments()
    ['bn_data', 'bn_gamma', 'bn_beta']
    >>> bn.list_auxiliary_states()
    ['bn_moving_mean', 'bn_moving_var']
    >>> bn.list_inputs()
    ['bn_data', 'bn_gamma', 'bn_beta', 'bn_moving_mean', 'bn_moving_var']
    """
    count = ctypes.c_uint()
    raw_names = ctypes.POINTER(ctypes.c_char_p)()
    # NOTE(review): the literal 0 is an option argument of the backend
    # call; its meaning is not visible here -- confirm against the C API.
    check_call(_LIB.NNSymbolListInputNames(
        self.handle, 0, ctypes.byref(count), ctypes.byref(raw_names)))
    names = []
    for idx in range(count.value):
        names.append(py_str(raw_names[idx]))
    return names
| |
def infer_type(self, *args, **kwargs):
    """Infers the type of all arguments and all outputs, given the known types
    for some arguments.

    This function takes the known types of some arguments in either positional way
    or keyword argument way as input. It returns a tuple of `None` values
    if there is not enough information to deduce the missing types.

    Inconsistencies in the known types will cause an error to be raised.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> b = mx.sym.var('b')
    >>> c = a + b
    >>> arg_types, out_types, aux_types = c.infer_type(a='float32')
    >>> arg_types
    [<type 'numpy.float32'>, <type 'numpy.float32'>]
    >>> out_types
    [<type 'numpy.float32'>]
    >>> aux_types
    []

    Parameters
    ----------
    *args :
        Type of known arguments in a positional way.
        Unknown type can be marked as None.

    **kwargs :
        Keyword arguments of known types. A keyword whose type has no
        entry in the numpy-to-MXNet type table is silently left out.

    Returns
    -------
    arg_types : list of numpy.dtype or None
        List of argument types.
        The order is same as the order of list_arguments().
    out_types : list of numpy.dtype or None
        List of output types.
        The order is same as the order of list_outputs().
    aux_types : list of numpy.dtype or None
        List of auxiliary state types.
        The order is same as the order of list_auxiliary_states().

    Raises
    ------
    ValueError
        If types are given both positionally and by keyword.
    TypeError
        If a positional type has no entry in the numpy-to-MXNet type table.
    """
    # pylint: disable=too-many-locals
    if args and kwargs:
        # Adjacent literals keep source indentation out of the message.
        raise ValueError('Can only specify known argument '
                         'types either by positional or kwargs way.')
    sdata = []
    if args:
        # Positional mode: no key array at all. ``None`` is handed to the
        # backend as-is instead of being wrapped by ``c_array``.
        keys = None
        for known in args:
            if known is None:
                sdata.append(-1)  # -1 marks an unknown type
                continue
            np_type = _numpy.dtype(known).type
            if np_type not in _DTYPE_NP_TO_MX:
                raise TypeError('Argument need to be one of ' + str(_DTYPE_NP_TO_MX))
            sdata.append(_DTYPE_NP_TO_MX[np_type])
    else:
        key_list = []
        for arg_name, known in kwargs.items():
            np_type = _numpy.dtype(known).type
            if np_type in _DTYPE_NP_TO_MX:
                key_list.append(c_str(arg_name))
                sdata.append(_DTYPE_NP_TO_MX[np_type])
        keys = c_array(ctypes.c_char_p, key_list)
    arg_type_size = mx_uint()
    arg_type_data = ctypes.POINTER(ctypes.c_int)()
    out_type_size = mx_uint()
    out_type_data = ctypes.POINTER(ctypes.c_int)()
    aux_type_size = mx_uint()
    aux_type_data = ctypes.POINTER(ctypes.c_int)()
    complete = ctypes.c_int()
    check_call(_LIB.MXSymbolInferType(
        self.handle,
        mx_uint(len(sdata)),
        keys,
        c_array(ctypes.c_int, sdata),
        ctypes.byref(arg_type_size),
        ctypes.byref(arg_type_data),
        ctypes.byref(out_type_size),
        ctypes.byref(out_type_data),
        ctypes.byref(aux_type_size),
        ctypes.byref(aux_type_data),
        ctypes.byref(complete)))
    if complete.value == 0:
        return (None, None, None)

    def _decode(size, data):
        """Maps the first `size` backend type codes to numpy types."""
        return [_DTYPE_MX_TO_NP[data[i]] for i in range(size.value)]

    arg_types = _decode(arg_type_size, arg_type_data)
    out_types = _decode(out_type_size, out_type_data)
    aux_types = _decode(aux_type_size, aux_type_data)
    return (arg_types, out_types, aux_types)
    # pylint: enable=too-many-locals
| |
def infer_shape(self, *args, **kwargs):
    """Infers the shapes of all arguments and all outputs given the known shapes of
    some arguments.

    This function takes the known shapes of some arguments in either positional way
    or keyword argument way as input. It returns a tuple of `None` values
    if there is not enough information to deduce the missing shapes; in
    that case a warning listing the undetermined arguments is also issued.

    Example usage:
    ----------
    >>> a = mx.sym.var('a')
    >>> b = mx.sym.var('b')
    >>> c = a + b
    >>> arg_shapes, out_shapes, aux_shapes = c.infer_shape(a=(3,3))
    >>> arg_shapes
    [(3L, 3L), (3L, 3L)]
    >>> out_shapes
    [(3L, 3L)]
    >>> aux_shapes
    []
    >>> c.infer_shape(a=(0,3)) # 0s in shape means unknown dimensions. So, returns None.
    (None, None, None)

    Inconsistencies in the known shapes will cause an error to be raised.
    See the following example:

    >>> data = mx.sym.Variable('data')
    >>> out = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=1000)
    >>> out = mx.sym.Activation(data=out, act_type='relu')
    >>> out = mx.sym.FullyConnected(data=out, name='fc2', num_hidden=10)
    >>> weight_shape= (1, 100)
    >>> data_shape = (100, 100)
    >>> out.infer_shape(data=data_shape, fc1_weight=weight_shape)
    Error in operator fc1: Shape inconsistent, Provided=(1,100), inferred shape=(1000,100)

    Parameters
    ----------
    *args :
        Shape of arguments in a positional way.
        Unknown shape can be marked as None.

    **kwargs :
        Keyword arguments of the known shapes.

    Returns
    -------
    arg_shapes : list of tuple or None
        List of argument shapes.
        The order is same as the order of list_arguments().
    out_shapes : list of tuple or None
        List of output shapes.
        The order is same as the order of list_outputs().
    aux_shapes : list of tuple or None
        List of auxiliary state shapes.
        The order is same as the order of list_auxiliary_states().

    Raises
    ------
    MXNetError
        Re-raised after the supplied arguments have been printed.
    """
    try:
        res = self._infer_shape_impl(False, *args, **kwargs)
        if res[1] is None:
            # Full inference failed; run the partial variant only to find
            # out which arguments are still undetermined.
            arg_shapes, _, _ = self._infer_shape_impl(True, *args, **kwargs)
            arg_names = self.list_arguments()
            if arg_shapes is None:
                # The partial pass can also report nothing; then every
                # argument counts as unknown instead of crashing in zip().
                arg_shapes = [None] * len(arg_names)
            unknowns = []
            for name, shape in zip(arg_names, arg_shapes):
                if not shape or not _numpy.prod(shape):
                    if len(unknowns) >= 10:
                        # Keep the warning short: list at most ten entries.
                        unknowns.append('...')
                        break
                    unknowns.append('%s: %s' % (name, str(shape)))
            warnings.warn(
                "Cannot decide shape for the following arguments " +
                "(0s in shape means unknown dimensions). " +
                "Consider providing them as input:\n\t" +
                "\n\t".join(unknowns), stacklevel=2)
        return res
    except MXNetError:
        print("infer_shape error. Arguments:")
        for i, arg in enumerate(args):
            print("  #%d: %s" % (i, arg))
        for k, v in kwargs.items():
            print("  %s: %s" % (k, v))
        raise
| |
def infer_shape_partial(self, *args, **kwargs):
    """Infers the shape partially.

    This functions works the same way as `infer_shape`,
    except that this function can return partial results.

    In the following example, information about fc2 is not available. So, `infer_shape`
    will return a tuple of `None` values but `infer_shape_partial` will return partial values.

    Example usage:
    ----------
    >>> data = mx.sym.Variable('data')
    >>> prev = mx.sym.Variable('prev')
    >>> fc1  = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
    >>> fc2  = mx.sym.FullyConnected(data=prev, name='fc2', num_hidden=128)
    >>> out  = mx.sym.Activation(data=mx.sym.elemwise_add(fc1, fc2), act_type='relu')
    >>> out.list_arguments()
    ['data', 'fc1_weight', 'fc1_bias', 'prev', 'fc2_weight', 'fc2_bias']
    >>> out.infer_shape(data=(10,64))
    (None, None, None)
    >>> out.infer_shape_partial(data=(10,64))
    ([(10L, 64L), (128L, 64L), (128L,), (), (), ()], [(10L, 128L)], [])

    Parameters
    ----------
    *args :
        Shape of arguments in a positional way.
        Unknown shape can be marked as None

    **kwargs :
        Keyword arguments of known shapes.

    Returns
    -------
    arg_shapes : list of tuple or None
        List of argument shapes.
        The order is same as the order of list_arguments().
    out_shapes : list of tuple or None
        List of output shapes.
        The order is same as the order of list_outputs().
    aux_shapes : list of tuple or None
        List of auxiliary state shapes.
        The order is same as the order of list_auxiliary_states().
    """
    # ``True`` selects the partial variant of the shared implementation.
    partial_result = self._infer_shape_impl(True, *args, **kwargs)
    return partial_result
| |
def _infer_shape_impl(self, partial, *args, **kwargs):
    """The actual implementation for calling shape inference API.

    Parameters
    ----------
    partial : bool
        Selects ``MXSymbolInferShapePartial`` when true and
        ``MXSymbolInferShape`` otherwise.
    *args :
        Known shapes (tuples) given positionally; ``None`` marks an
        unknown shape.
    **kwargs :
        Known shapes (tuples) given by argument name.

    Returns
    -------
    tuple
        ``(arg_shapes, out_shapes, aux_shapes)``, each a list of tuples,
        or ``(None, None, None)`` when the backend reports that inference
        is not complete.

    Raises
    ------
    ValueError
        If shapes are given both positionally and by keyword.
    TypeError
        If a supplied shape is not a tuple.
    """
    # pylint: disable=too-many-locals
    if args and kwargs:
        # Adjacent literals keep source indentation out of the message.
        raise ValueError('Can only specify known argument '
                         'shapes either by positional or kwargs way.')
    # All shapes are flattened into `sdata`; `indptr[i]:indptr[i + 1]`
    # delimits the dimensions of the i-th shape.
    sdata = []
    indptr = [0]
    if args:
        # Positional mode: no key array at all. ``None`` is handed to the
        # backend as-is instead of being wrapped by ``c_array``.
        keys = None
        for i, s in enumerate(args):
            if s is not None:
                if not isinstance(s, tuple):
                    raise TypeError("Arguments need to be shapes (tuple), "
                                    "but argument %d is %s." % (i, type(s)))
                sdata.extend(s)
            indptr.append(len(sdata))
    else:
        key_list = []
        for k, v in kwargs.items():
            if not isinstance(v, tuple):
                raise TypeError("Arguments need to be shapes (tuple), "
                                "but '%s' is %s." % (k, type(v)))
            key_list.append(c_str(k))
            sdata.extend(v)
            indptr.append(len(sdata))
        keys = c_array(ctypes.c_char_p, key_list)
    arg_shape_size = mx_uint()
    arg_shape_ndim = ctypes.POINTER(mx_uint)()
    arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
    out_shape_size = mx_uint()
    out_shape_ndim = ctypes.POINTER(mx_uint)()
    out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
    aux_shape_size = mx_uint()
    aux_shape_ndim = ctypes.POINTER(mx_uint)()
    aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
    complete = ctypes.c_int()
    if partial:
        infer_func = _LIB.MXSymbolInferShapePartial
    else:
        infer_func = _LIB.MXSymbolInferShape
    check_call(infer_func(
        self.handle,
        mx_uint(len(indptr) - 1),
        keys,
        c_array(mx_uint, indptr),
        c_array(mx_uint, sdata),
        ctypes.byref(arg_shape_size),
        ctypes.byref(arg_shape_ndim),
        ctypes.byref(arg_shape_data),
        ctypes.byref(out_shape_size),
        ctypes.byref(out_shape_ndim),
        ctypes.byref(out_shape_data),
        ctypes.byref(aux_shape_size),
        ctypes.byref(aux_shape_ndim),
        ctypes.byref(aux_shape_data),
        ctypes.byref(complete)))
    if complete.value == 0:
        return (None, None, None)

    def _unpack(size, ndim, data):
        """Converts `size` backend shapes into a list of tuples."""
        return [tuple(data[i][:ndim[i]]) for i in range(size.value)]

    arg_shapes = _unpack(arg_shape_size, arg_shape_ndim, arg_shape_data)
    out_shapes = _unpack(out_shape_size, out_shape_ndim, out_shape_data)
    aux_shapes = _unpack(aux_shape_size, aux_shape_ndim, aux_shape_data)
    return (arg_shapes, out_shapes, aux_shapes)
    # pylint: enable=too-many-locals
| |
def debug_str(self):
    """Returns the debug string of this symbol.

    The string is produced by the backend (`MXSymbolPrint`) and describes the
    symbol outputs, variables and operators of the computation graph together
    with their inputs and attributes.

    Returns
    -------
    string
        Debug string of the symbol.

    Examples
    --------
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.sin(a)
    >>> c = 2 * a + b
    >>> d = mx.sym.FullyConnected(data=c, num_hidden=10)
    >>> print(d.debug_str())
    Symbol Outputs:
        output[0]=fullyconnected0(0)
    Variable:a
    --------------------
    Op:_mul_scalar, Name=_mulscalar0
    Inputs:
        arg[0]=a(0) version=0
    Attrs:
        scalar=2
    --------------------
    Op:sin, Name=sin0
    Inputs:
        arg[0]=a(0) version=0
    --------------------
    Op:elemwise_add, Name=_plus0
    Inputs:
        arg[0]=_mulscalar0(0)
        arg[1]=sin0(0)
    Variable:fullyconnected0_weight
    Variable:fullyconnected0_bias
    --------------------
    Op:FullyConnected, Name=fullyconnected0
    Inputs:
        arg[0]=_plus0(0)
        arg[1]=fullyconnected0_weight(0) version=0
        arg[2]=fullyconnected0_bias(0) version=0
    Attrs:
        num_hidden=10
    """
    # The backend fills in this char pointer with the rendered description.
    text_ptr = ctypes.c_char_p()
    status = _LIB.MXSymbolPrint(self.handle, ctypes.byref(text_ptr))
    check_call(status)
    return py_str(text_ptr.value)
| |
def save(self, fname):
    """Writes this symbol to a file.

    Pickle also works when only Python is involved. The `load`/`save` pair
    produces language-agnostic file contents, so a model saved by one language
    binding of `MXNet` can be loaded by another. It also allows loading from and
    saving to cloud storage (S3, HDFS) directly.

    Parameters
    ----------
    fname : str
        The name of the file.

        - "s3://my-bucket/path/my-s3-symbol"
        - "hdfs://my-bucket/path/my-hdfs-symbol"
        - "/path-to/my-local-symbol"

    Raises
    ------
    TypeError
        If `fname` is not a string.

    See Also
    --------
    symbol.load : Used to load symbol from file.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname need to be string')
    target = c_str(fname)
    status = _LIB.MXSymbolSaveToFile(self.handle, target)
    check_call(status)
| |
def tojson(self):
    """Serializes this symbol to a JSON string.

    Returns
    -------
    str
        The string produced by the backend call `MXSymbolSaveToJSON`.

    See Also
    --------
    symbol.load_json : Used to load symbol from JSON string.
    """
    # Output parameter: the backend points this at the JSON text.
    json_ptr = ctypes.c_char_p()
    status = _LIB.MXSymbolSaveToJSON(self.handle, ctypes.byref(json_ptr))
    check_call(status)
    return py_str(json_ptr.value)
| |
@staticmethod
def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
    """Helper that turns user-supplied NDArrays into a positional handle array.

    Parameters
    ----------
    arg_key : str
        The name of the argument, used only in error messages.

    args : list of NDArray or dict of str to NDArray
        Input arguments to the symbol.
        A list is taken positionally, in the same order as `arg_names`.
        A dict maps argument names to the corresponding NDArray.

    arg_names : list of string
        List of argument names.

    allow_missing : boolean
        Whether a name absent from a dict `args` is tolerated.
        When allowed, the missing handle and array are set to None (null).

    Returns
    -------
    handles : ctypes array of NDArrayHandle
        The positional handles generated from the input.
    arrays : list
        The NDArrays (or None placeholders) in the same positional order.
        For list input this is the `args` list itself.

    Raises
    ------
    ValueError
        If a list has the wrong length, or a required dict key is missing.
    TypeError
        If `args` is neither a list nor a dict, or holds a non-NDArray.
    """
    bad_type_msg = 'Only accept list of NDArrays or dict of str to NDArray'

    if isinstance(args, list):
        if len(args) != len(arg_names):
            raise ValueError('Length of %s does not match the number of arguments' % arg_key)
        handles = []
        for array in args:
            if not isinstance(array, NDArray):
                raise TypeError(bad_type_msg)
            handles.append(array.handle)
        return c_array(NDArrayHandle, handles), args

    if isinstance(args, dict):
        handles = []
        arrays = []
        for name in arg_names:
            if name not in args:
                if not allow_missing:
                    raise ValueError('key `%s` is missing in `%s`' % (name, arg_key))
                # Keep the slot so positions still line up with arg_names.
                handles.append(None)
                arrays.append(None)
                continue
            array = args[name]
            if not isinstance(array, NDArray):
                raise TypeError(bad_type_msg)
            handles.append(array.handle)
            arrays.append(array)
        return c_array(NDArrayHandle, handles), arrays

    raise TypeError(bad_type_msg)
| |
def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
                shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs):
    """Bind current symbol to get an executor, allocate all the arguments needed.
    Allows specifying data types.

    This function simplifies the binding procedure. You need to specify only input data shapes.
    Before binding the executor, the function allocates arguments and auxiliary states
    that were not explicitly specified. Allows specifying data types.

    Example usage:
    ----------
    >>> x = mx.sym.Variable('x')
    >>> y = mx.sym.FullyConnected(x, num_hidden=4)
    >>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null')
    >>> exe.forward()
    [<NDArray 5x4 @cpu(0)>]
    >>> exe.outputs[0].asnumpy()
    array([[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]], dtype=float32)
    >>> exe.arg_arrays
    [<NDArray 5x4 @cpu(0)>, <NDArray 4x4 @cpu(0)>, <NDArray 4 @cpu(0)>]
    >>> exe.grad_arrays
    [<NDArray 5x4 @cpu(0)>, <NDArray 4x4 @cpu(0)>, <NDArray 4 @cpu(0)>]

    Parameters
    ----------
    ctx : Context
        The device context the generated executor to run on.

    grad_req: string
        {'write', 'add', 'null'}, or list of str or dict of str to str, optional
        To specify how we should update the gradient to the `args_grad`.

        - 'write' means every time gradient is written to specified `args_grad` NDArray.
        - 'add' means every time gradient is added to the specified NDArray.
        - 'null' means no action is taken, the gradient may not be calculated.

    type_dict : Dict of str->numpy.dtype
        Input type dictionary, name->dtype. Entries whose dtype is not a key
        of `_DTYPE_NP_TO_MX` are skipped.

    group2ctx : Dict of string to mx.Context
        The dict mapping the `ctx_group` attribute to the context assignment.

    shared_arg_names : List of string
        The argument names whose `NDArray` of shared_exec can be reused for initializing
        the current executor.

    shared_exec : Executor
        The executor whose arg_arrays, grad_arrays, and aux_arrays can be
        reused for initializing the current executor.

    shared_buffer : Dict of string to `NDArray`
        The dict mapping argument names to the `NDArray` that can be reused for initializing
        the current executor. This buffer will be checked for reuse if one argument name
        of the current executor is not found in `shared_arg_names`.
        The dict is updated in place with the entries reported back by the backend.

    kwargs : Dict of str->shape
        Input shape dictionary, name->shape. Only tuple values are forwarded.

    Returns
    -------
    executor : mxnet.Executor
        The generated executor

    Raises
    ------
    TypeError
        If `grad_req` is not None, a string, a list or a dict.
    ValueError
        If `shared_arg_names` is not a list or None, or `shared_buffer` is not
        a dict or None.
    RuntimeError
        If `grad_req` is an empty list or dict, or if the backend call fails
        (the `MXNetError` text is appended to a dump of `kwargs`).
    """
    num_provided_arg_types = 0
    provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)()  # provided type argument names
    provided_arg_type_data = ctypes.POINTER(mx_uint)()  # provided types
    if type_dict is not None:
        provided_arg_type_names = []
        provided_arg_type_data = []
        for k, v in type_dict.items():
            v = _numpy.dtype(v).type
            if v in _DTYPE_NP_TO_MX:
                provided_arg_type_names.append(c_str(k))
                provided_arg_type_data.append(ctypes.c_int(_DTYPE_NP_TO_MX[v]))
        num_provided_arg_types = mx_uint(len(provided_arg_type_names))
        provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names)
        provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data)

    provided_arg_shape_data = []  # shape data
    # argument shape index in sdata,
    # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
    provided_arg_shape_idx = [0]
    provided_arg_shape_names = []  # provided argument names
    for k, v in kwargs.items():
        # Non-tuple values are silently skipped.
        if isinstance(v, tuple):
            provided_arg_shape_names.append(c_str(k))
            provided_arg_shape_data.extend(v)
            provided_arg_shape_idx.append(len(provided_arg_shape_data))

    provided_req_type_list_len = 0
    provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)()
    provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)()
    if grad_req is not None:
        if isinstance(grad_req, string_types):
            # use provided_req_type_list_len = 0 to indicate this situation
            provided_req_type_list_len = 0
            provided_grad_req_types = [c_str(grad_req)]
        elif isinstance(grad_req, list):
            if len(grad_req) == 0:
                raise RuntimeError('grad_req in simple_bind cannot be an empty list')
            provided_grad_req_types = [c_str(item) for item in grad_req]
            provided_req_type_list_len = len(provided_grad_req_types)
        elif isinstance(grad_req, dict):
            if len(grad_req) == 0:
                raise RuntimeError('grad_req in simple_bind cannot be an empty dict')
            provided_grad_req_names = []
            provided_grad_req_types = []
            for k, v in grad_req.items():
                provided_grad_req_names.append(c_str(k))
                provided_grad_req_types.append(c_str(v))
            provided_grad_req_names = c_array(ctypes.c_char_p, provided_grad_req_names)
            provided_req_type_list_len = len(provided_grad_req_types)
        else:
            # Reject unsupported containers up front instead of falling through
            # to c_array with the null-pointer placeholder set above.
            raise TypeError('grad_req in simple_bind must be a string, a list, a dict '
                            'or None, but got %s' % type(grad_req))
        provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types)

    num_ctx_map_keys = mx_uint(0)
    ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)()
    ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)()
    ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)()
    if group2ctx is not None:
        ctx_map_keys = []
        ctx_map_dev_types = []
        ctx_map_dev_ids = []
        for key, val in group2ctx.items():
            ctx_map_keys.append(c_str(key))
            ctx_map_dev_types.append(ctypes.c_int(val.device_typeid))
            ctx_map_dev_ids.append(ctypes.c_int(val.device_id))
        num_ctx_map_keys = mx_uint(len(ctx_map_keys))
        ctx_map_keys = c_array(ctypes.c_char_p, ctx_map_keys)
        ctx_map_dev_types = c_array(ctypes.c_int, ctx_map_dev_types)
        ctx_map_dev_ids = c_array(ctypes.c_int, ctx_map_dev_ids)

    # prepare param names
    shared_arg_name_list = []
    if shared_arg_names is not None:
        if not isinstance(shared_arg_names, list):
            raise ValueError('shared_arg_names in simple_bind must be a list or None')
        shared_arg_name_list = [c_str(name) for name in shared_arg_names]

    # prepare shared_buffer
    if shared_buffer is None:
        # A length of -1 is sent together with null name/handle arrays when
        # the caller supplied no buffer.
        shared_buffer_len = ctypes.c_int(-1)
        shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
        shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()
    else:
        if not isinstance(shared_buffer, dict):
            raise ValueError('shared_buffer in simple_bind must be dict or None')
        shared_buffer_names = []
        shared_buffer_handles = []
        for k, v in shared_buffer.items():
            shared_buffer_names.append(c_str(k))
            shared_buffer_handles.append(v.handle)
        shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names)
        shared_buffer_len = ctypes.c_int(len(shared_buffer_handles))
        shared_buffer_handles = c_array(NDArrayHandle, shared_buffer_handles)
    updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)()
    updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)()

    # prepare shared_exec_handle
    shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()

    # prepare current executor handle
    exe_handle = ExecutorHandle()

    # prepare current executor's in_args, arg_grads, and aux_states
    num_in_args = ctypes.c_uint()
    in_arg_handles = ctypes.POINTER(NDArrayHandle)()
    arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
    num_aux_states = ctypes.c_uint()
    aux_state_handles = ctypes.POINTER(NDArrayHandle)()

    try:
        check_call(_LIB.MXExecutorSimpleBind(self.handle,
                                             ctypes.c_int(ctx.device_typeid),
                                             ctypes.c_int(ctx.device_id),
                                             num_ctx_map_keys,
                                             ctx_map_keys,
                                             ctx_map_dev_types,
                                             ctx_map_dev_ids,
                                             mx_uint(provided_req_type_list_len),
                                             provided_grad_req_names,
                                             provided_grad_req_types,
                                             mx_uint(len(provided_arg_shape_names)),
                                             c_array(ctypes.c_char_p, provided_arg_shape_names),
                                             c_array(mx_uint, provided_arg_shape_data),
                                             c_array(mx_uint, provided_arg_shape_idx),
                                             num_provided_arg_types,
                                             provided_arg_type_names,
                                             provided_arg_type_data,
                                             mx_uint(len(shared_arg_name_list)),
                                             c_array(ctypes.c_char_p, shared_arg_name_list),
                                             ctypes.byref(shared_buffer_len),
                                             shared_buffer_names,
                                             shared_buffer_handles,
                                             ctypes.byref(updated_shared_buffer_names),
                                             ctypes.byref(updated_shared_buffer_handles),
                                             ctypes.byref(num_in_args),
                                             ctypes.byref(in_arg_handles),
                                             ctypes.byref(arg_grad_handles),
                                             ctypes.byref(num_aux_states),
                                             ctypes.byref(aux_state_handles),
                                             shared_exec_handle,
                                             ctypes.byref(exe_handle)))
    except MXNetError as e:
        # Prefix the backend error with the shapes the caller passed in.
        error_msg = "simple_bind error. Arguments:\n"
        for k, v in kwargs.items():
            error_msg += "%s: %s\n" % (k, v)
        error_msg += "%s" % e
        raise RuntimeError(error_msg)

    # update shared_buffer
    # shared_buffer_len was passed by reference, so its value here is whatever
    # the backend left in it.
    if shared_buffer is not None:
        for i in range(shared_buffer_len.value):
            k = py_str(updated_shared_buffer_names[i])
            v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i]))
            shared_buffer[k] = v

    # create in_args, arg_grads, and aux_states for the current executor
    arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)]
    # A null gradient handle becomes a None entry.
    grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i]))
                   if arg_grad_handles[i] is not None
                   else None for i in range(num_in_args.value)]
    aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i]))
                  for i in range(num_aux_states.value)]

    executor = Executor(exe_handle, self, ctx, grad_req, group2ctx)
    executor.arg_arrays = arg_arrays
    executor.grad_arrays = grad_arrays
    executor.aux_arrays = aux_arrays
    return executor
| |
def bind(self, ctx, args, args_grad=None, grad_req='write',
         aux_states=None, group2ctx=None, shared_exec=None):
    """Binds the current symbol to an executor and returns it.

    We first declare the computation and then bind to the data to run.
    This function returns an executor which provides a `forward()` method for evaluation
    and an `outputs` attribute to get all the results.

    Example usage:
    ----------
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.Variable('b')
    >>> c = a + b
    >>> c
    <Symbol _plus1>
    >>> ex = c.bind(ctx=mx.cpu(), args={'a' : mx.nd.ones([2,3]), 'b' : mx.nd.ones([2,3])})
    >>> ex.forward()
    [<NDArray 2x3 @cpu(0)>]
    >>> ex.outputs[0].asnumpy()
    [[ 2.  2.  2.]
     [ 2.  2.  2.]]

    Parameters
    ----------
    ctx : Context
        The device context the generated executor to run on.

    args : list of NDArray or dict of str to NDArray
        Input arguments to the symbol.

        - If the input type is a list of `NDArray`, the order should be same as the order
          of `list_arguments()`.
        - If the input type is a dict of str to `NDArray`, then it maps the name of arguments
          to the corresponding `NDArray`.
        - In either case, all the arguments must be provided.

    args_grad : list of NDArray or dict of str to `NDArray`, optional
        When specified, `args_grad` provides NDArrays to hold
        the result of gradient value in backward.

        - If the input type is a list of `NDArray`, the order should be same as the order
          of `list_arguments()`.
        - If the input type is a dict of str to `NDArray`, then it maps the name of arguments
          to the corresponding NDArray.
        - When the type is a dict of str to `NDArray`, one only need to provide the dict
          for required argument gradient.
          Only the specified argument gradient will be calculated.

    grad_req : {'write', 'add', 'null'}, or list of str or dict of str to str, optional
        To specify how we should update the gradient to the `args_grad`.

        - 'write' means every time gradient is written to specified `args_grad` `NDArray`.
        - 'add' means every time gradient is added to the specified NDArray.
        - 'null' means no action is taken, the gradient may not be calculated.

        With a dict, arguments that are not listed get request code 0
        (the code `_GRAD_REQ_MAP` assigns to 'null').

    aux_states : list of `NDArray`, or dict of str to `NDArray`, optional
        Input auxiliary states to the symbol, only needed when the output of
        `list_auxiliary_states()` is not empty.

        - If the input type is a list of `NDArray`, the order should be same as the order
          of `list_auxiliary_states()`.
        - If the input type is a dict of str to `NDArray`, then it maps the name of
          `auxiliary_states` to the corresponding `NDArray`,
        - In either case, all the auxiliary states need to be provided.

    group2ctx : Dict of string to mx.Context
        The dict mapping the `ctx_group` attribute to the context assignment.

    shared_exec : mx.executor.Executor
        Executor to share memory with. This is intended for runtime reshaping, variable length
        sequences, etc. The returned executor shares state with `shared_exec`, and should not be
        used in parallel with it.

    Returns
    -------
    executor : Executor
        The generated executor

    Raises
    ------
    TypeError
        If `ctx` is not a `Context`, or `grad_req` is not a string, list or dict.
    ValueError
        If a string `grad_req` is not a key of `_GRAD_REQ_MAP`.

    Notes
    -----
    Auxiliary states are the special states of symbols that do not correspond
    to an argument, and do not have gradient but are still useful
    for the specific operations. Common examples of auxiliary states include
    the `moving_mean` and `moving_variance` states in `BatchNorm`.
    Most operators do not have auxiliary states and in those cases,
    this parameter can be safely ignored.

    One can give up gradient by using a dict in `args_grad` and only specify
    gradient they interested in.
    """
    # pylint: disable=too-many-locals, too-many-branches
    if not isinstance(ctx, Context):
        raise TypeError("Context type error")

    listed_arguments = self.list_arguments()
    args_handle, args = self._get_ndarray_inputs('args', args, listed_arguments, False)
    # setup args gradient
    if args_grad is None:
        # No gradient storage: one null handle per argument.
        args_grad_handle = c_array(NDArrayHandle, [None] * len(args))
    else:
        args_grad_handle, args_grad = self._get_ndarray_inputs(
            'args_grad', args_grad, listed_arguments, True)

    if aux_states is None:
        aux_states = []
    aux_args_handle, aux_states = self._get_ndarray_inputs(
        'aux_states', aux_states, self.list_auxiliary_states(), False)

    # setup requirements
    if isinstance(grad_req, string_types):
        if grad_req not in _GRAD_REQ_MAP:
            raise ValueError('grad_req must be in %s' % str(_GRAD_REQ_MAP))
        # The same request applies to every argument.
        reqs_array = c_array(
            mx_uint,
            [mx_uint(_GRAD_REQ_MAP[grad_req])] * len(listed_arguments))
    elif isinstance(grad_req, list):
        reqs_array = c_array(mx_uint, [mx_uint(_GRAD_REQ_MAP[item]) for item in grad_req])
    elif isinstance(grad_req, dict):
        req_array = []
        for name in listed_arguments:
            if name in grad_req:
                req_array.append(mx_uint(_GRAD_REQ_MAP[grad_req[name]]))
            else:
                req_array.append(mx_uint(0))
        reqs_array = c_array(mx_uint, req_array)
    else:
        # Without this branch `reqs_array` would stay unbound and the failure
        # would only surface as an UnboundLocalError at the backend call.
        raise TypeError('grad_req must be a string, a list or a dict, '
                        'but got %s' % type(grad_req))

    ctx_map_keys = []
    ctx_map_dev_types = []
    ctx_map_dev_ids = []

    if group2ctx:
        for key, val in group2ctx.items():
            ctx_map_keys.append(c_str(key))
            ctx_map_dev_types.append(ctypes.c_int(val.device_typeid))
            ctx_map_dev_ids.append(ctypes.c_int(val.device_id))

    handle = ExecutorHandle()
    shared_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle()
    check_call(_LIB.MXExecutorBindEX(self.handle,
                                     ctypes.c_int(ctx.device_typeid),
                                     ctypes.c_int(ctx.device_id),
                                     mx_uint(len(ctx_map_keys)),
                                     c_array(ctypes.c_char_p, ctx_map_keys),
                                     c_array(ctypes.c_int, ctx_map_dev_types),
                                     c_array(ctypes.c_int, ctx_map_dev_ids),
                                     mx_uint(len(args)),
                                     args_handle,
                                     args_grad_handle,
                                     reqs_array,
                                     mx_uint(len(aux_states)),
                                     aux_args_handle,
                                     shared_handle,
                                     ctypes.byref(handle)))
    executor = Executor(handle, self, ctx, grad_req, group2ctx)
    executor.arg_arrays = args
    executor.grad_arrays = args_grad
    executor.aux_arrays = aux_states
    return executor
| |
def grad(self, wrt):
    """Gets the autodiff of current symbol.

    This function can only be used if current symbol is a loss function.

    .. note:: This function is currently not implemented.

    Parameters
    ----------
    wrt : Array of String
        Names of the arguments with respect to which the gradients are taken.

    Returns
    -------
    grad : Symbol
        A gradient Symbol whose outputs are the corresponding gradients.
    """
    grad_handle = SymbolHandle()
    encoded_names = [c_str(name) for name in wrt]
    c_wrt = c_array(ctypes.c_char_p, encoded_names)
    status = _LIB.MXSymbolGrad(
        self.handle, mx_uint(len(wrt)), c_wrt, ctypes.byref(grad_handle))
    check_call(status)
    return Symbol(grad_handle)
| |
| # pylint: enable= no-member |
| |
def eval(self, ctx=cpu(), **kwargs):
    """Evaluates a symbol given arguments.

    `eval` chains a call to `bind` (which returns an executor) with a call to
    the executor's `forward` method. Evaluating repeatedly with the same
    arguments this way is slow; for that use case call `bind` once and then
    call `forward` repeatedly. This method exists for a simpler syntax when
    inspecting a symbol.

    Example usage:
    ----------
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.Variable('b')
    >>> c = a + b
    >>> ex = c.eval(ctx = mx.cpu(), a = mx.nd.ones([2,3]), b = mx.nd.ones([2,3]))
    >>> ex
    [<NDArray 2x3 @cpu(0)>]
    >>> ex[0].asnumpy()
    array([[ 2.,  2.,  2.],
           [ 2.,  2.,  2.]], dtype=float32)

    Parameters
    ----------
    ctx : Context
        The device context the generated executor to run on.

    kwargs : Keyword arguments of type `NDArray`
        Input arguments to the symbol. All the arguments must be provided.

    Returns
    -------
    result : a list of NDArrays corresponding to the values taken by each symbol when
        evaluated on given args. When called on a single symbol (not a group),
        the result will be a list with one element.
    """
    # The keyword arguments are handed to `bind` as its name->NDArray dict.
    executor = self.bind(ctx, kwargs)
    return executor.forward()
| |
def reshape(self, shape):
    """Shorthand for mxnet.sym.reshape.

    Parameters
    ----------
    shape : tuple of int
        The new shape should not change the array size, namely
        ``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``.
        One shape dimension can be -1. In this case, the value is inferred
        from the length of the array and remaining dimensions.

    Returns
    -------
    Symbol
        A reshaped symbol.
    """
    # Inside a method body the bare name resolves to the module-level
    # `reshape`, not to this method.
    reshaped = reshape(self, shape=shape)
    return reshaped
| |
def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs):
    """Creates a symbolic variable with specified name.

    Example usage:
    ----------
    >>> data = mx.sym.Variable('data', attr={'a': 'b'})
    >>> data
    <Symbol data>

    Parameters
    ----------
    name : str
        Variable name.
    attr : Dict of strings
        Additional attributes to set on the variable. Format {string : string}.
        The dict passed in is not modified.
    shape : tuple
        The shape of a variable. If specified, this will be used during the shape inference.
        If one has specified a different shape for this variable using
        a keyword argument when calling shape inference, this shape information will be ignored.
    lr_mult : float
        The learning rate multiplier for input variable.
    wd_mult : float
        Weight decay multiplier for input variable.
    dtype : str or numpy.dtype
        The dtype for input variable. If not specified, this value will be inferred.
    init : initializer (mxnet.init.*)
        Initializer for this variable to (optionally) override the default initializer.
        A string is stored as given; any other object is stored as the
        result of its `dumps()` method.
    kwargs : Additional attribute variables
        Additional attributes must start and end with double underscores.

    Returns
    -------
    variable : Symbol
        A symbol corresponding to an input to the computation graph.

    Raises
    ------
    TypeError
        If `name` is not a string.
    ValueError
        If a keyword in `kwargs` is not wrapped in double underscores.
    """
    if not isinstance(name, string_types):
        raise TypeError('Expect a string for variable `name`')
    handle = SymbolHandle()
    check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
    ret = Symbol(handle)
    attr = AttrScope.current.get(attr)
    # Work on a private copy: the entries added below must not leak into a
    # dict owned by the caller.
    # NOTE(review): AttrScope.get may hand back the caller's own `attr`
    # object -- confirm against .attribute; the copy is harmless either way.
    attr = {} if attr is None else dict(attr)
    if shape is not None:
        attr['__shape__'] = str(shape)
    if lr_mult is not None:
        attr['__lr_mult__'] = str(lr_mult)
    if wd_mult is not None:
        attr['__wd_mult__'] = str(wd_mult)
    if dtype is not None:
        attr['__dtype__'] = str(_DTYPE_NP_TO_MX[_numpy.dtype(dtype).type])
    if init is not None:
        if not isinstance(init, string_types):
            init = init.dumps()
        attr['__init__'] = init
    for k, v in kwargs.items():
        if k.startswith('__') and k.endswith('__'):
            attr[k] = str(v)
        else:
            raise ValueError('Attribute name=%s is not supported.'
                             ' Additional attributes must start and end with double underscores,'
                             ' e.g, __yourattr__' % k)
    ret._set_attr(**attr)
    return ret
| |
| |
# `Variable` is kept as an alias of `var` for backward compatibility.
Variable = var
| |
| |
def Group(symbols):
    """Creates a symbol that contains a collection of other symbols, grouped together.

    Example usage:
    ----------
    >>> a = mx.sym.Variable('a')
    >>> b = mx.sym.Variable('b')
    >>> mx.sym.Group([a,b])
    <Symbol Grouped>

    Parameters
    ----------
    symbols : list
        List of symbols to be grouped.

    Returns
    -------
    sym : Symbol
        A group symbol.

    Raises
    ------
    TypeError
        If any element of `symbols` is not a `Symbol`.
    """
    def _checked_handle(member):
        # Validate and unwrap in a single pass over the input.
        if isinstance(member, Symbol):
            return member.handle
        raise TypeError('Expected a list of symbols as input')

    ihandles = [_checked_handle(member) for member in symbols]
    group_handle = SymbolHandle()
    status = _LIB.MXSymbolCreateGroup(
        mx_uint(len(ihandles)),
        c_array(SymbolHandle, ihandles),
        ctypes.byref(group_handle))
    check_call(status)
    return Symbol(group_handle)
| |
| |
def load(fname):
    """Loads symbol from a JSON file.

    Pickle also works when only Python is involved. The load/save pair keeps
    the file language agnostic, so a file written with save can be loaded by
    another language binding of mxnet. It also allows loading from and saving
    to cloud storage (S3, HDFS) directly.

    Parameters
    ----------
    fname : str
        The name of the file, examples:

        - `s3://my-bucket/path/my-s3-symbol`
        - `hdfs://my-bucket/path/my-hdfs-symbol`
        - `/path-to/my-local-symbol`

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    Raises
    ------
    TypeError
        If `fname` is not a string.

    See Also
    --------
    Symbol.save : Used to save symbol into file.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname need to be string')
    loaded = SymbolHandle()
    status = _LIB.MXSymbolCreateFromFile(c_str(fname), ctypes.byref(loaded))
    check_call(status)
    return Symbol(loaded)
| |
| |
def load_json(json_str):
    """Loads symbol from json string.

    Parameters
    ----------
    json_str : str
        A JSON string.

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    Raises
    ------
    TypeError
        If `json_str` is not a string.

    See Also
    --------
    Symbol.tojson : Used to save symbol into json string.
    """
    if not isinstance(json_str, string_types):
        # The message names the parameter that was actually checked.
        raise TypeError('json_str required to be string')
    handle = SymbolHandle()
    check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle)))
    return Symbol(handle)
| |
| |
| # pylint: disable=no-member |
| # pylint: disable=redefined-builtin |
def pow(base, exp):
    """Returns element-wise result of base element raised to powers from exp element.

    Both inputs can be Symbol or scalar number.
    Broadcasting is not supported. Use `broadcast_pow` instead.

    Parameters
    ---------
    base : Symbol or scalar
        The base symbol
    exp : Symbol or scalar
        The exponent symbol

    Returns
    -------
    Symbol or scalar
        `base` raised to the power `exp`. A plain Python value is returned
        only when both inputs are scalars.

    Raises
    ------
    TypeError
        If either input is neither a Symbol nor a number.

    Examples
    --------
    >>> mx.sym.pow(2, 3)
    8
    >>> x = mx.sym.Variable('x')
    >>> y = mx.sym.Variable('y')
    >>> z = mx.sym.pow(x, 2)
    >>> z.eval(x=mx.nd.array([1,2]))[0].asnumpy()
    array([ 1.,  4.], dtype=float32)
    >>> z = mx.sym.pow(3, y)
    >>> z.eval(y=mx.nd.array([2,3]))[0].asnumpy()
    array([  9.,  27.], dtype=float32)
    >>> z = mx.sym.pow(x, y)
    >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([2,3]))[0].asnumpy()
    array([  9.,  64.], dtype=float32)
    """
    base_is_symbol = isinstance(base, Symbol)
    exp_is_symbol = isinstance(exp, Symbol)
    if base_is_symbol and exp_is_symbol:
        return _internal._Power(base, exp)

    base_is_number = isinstance(base, Number)
    exp_is_number = isinstance(exp, Number)
    if base_is_symbol and exp_is_number:
        return _internal._PowerScalar(base, scalar=exp)
    if base_is_number and exp_is_symbol:
        # Scalar base, symbolic exponent: the symbol is passed first and the
        # scalar goes in as `scalar`.
        return _internal._RPowerScalar(exp, scalar=base)
    if base_is_number and exp_is_number:
        return base**exp
    raise TypeError('types (%s, %s) not supported' % (str(type(base)), str(type(exp))))
| |
| |
| # pylint: disable=no-member |
| # pylint: disable=redefined-builtin |
def maximum(left, right):
    """Returns element-wise maximum of the input elements.

    Both inputs can be Symbol or scalar number. Broadcasting is not supported.

    Parameters
    ---------
    left : Symbol or scalar
        First symbol to be compared.
    right : Symbol or scalar
        Second symbol to be compared.

    Returns
    -------
    Symbol or scalar
        The element-wise maximum of the input symbols. For two scalars,
        `right` is returned unless `left` is strictly greater.

    Raises
    ------
    TypeError
        If either input is neither a Symbol nor a number.

    Examples
    --------
    >>> mx.sym.maximum(2, 3.5)
    3.5
    >>> x = mx.sym.Variable('x')
    >>> y = mx.sym.Variable('y')
    >>> z = mx.sym.maximum(x, 4)
    >>> z.eval(x=mx.nd.array([3,5,2,10]))[0].asnumpy()
    array([  4.,   5.,   4.,  10.], dtype=float32)
    >>> z = mx.sym.maximum(x, y)
    >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([10,2]))[0].asnumpy()
    array([ 10.,   4.], dtype=float32)
    """
    left_is_symbol = isinstance(left, Symbol)
    right_is_symbol = isinstance(right, Symbol)
    if left_is_symbol and right_is_symbol:
        return _internal._Maximum(left, right)

    left_is_number = isinstance(left, Number)
    right_is_number = isinstance(right, Number)
    if left_is_symbol and right_is_number:
        return _internal._MaximumScalar(left, scalar=right)
    if left_is_number and right_is_symbol:
        return _internal._MaximumScalar(right, scalar=left)
    if left_is_number and right_is_number:
        # Written out rather than using max(): on a tie this returns `right`.
        if left > right:
            return left
        return right
    raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right))))
| |
| |
| # pylint: disable=no-member |
| # pylint: disable=redefined-builtin |
def minimum(left, right):
    """Compute the element-wise minimum of two operands.

    Each operand may be a Symbol or a scalar number. Broadcasting is not supported.

    Parameters
    ----------
    left : Symbol or scalar
        First operand to be compared.
    right : Symbol or scalar
        Second operand to be compared.

    Returns
    -------
    Symbol or scalar
        A Symbol when at least one operand is a Symbol; otherwise the smaller
        of the two numbers.

    Raises
    ------
    TypeError
        If the operands are not a supported combination of Symbol and number.

    Examples
    --------
    >>> mx.sym.minimum(2, 3.5)
    2
    >>> x = mx.sym.Variable('x')
    >>> y = mx.sym.Variable('y')
    >>> z = mx.sym.minimum(x, 4)
    >>> z.eval(x=mx.nd.array([3,5,2,10]))[0].asnumpy()
    array([ 3., 4., 2., 4.], dtype=float32)
    >>> z = mx.sym.minimum(x, y)
    >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([10,2]))[0].asnumpy()
    array([ 3., 2.], dtype=float32)
    """
    if isinstance(left, Symbol):
        if isinstance(right, Symbol):
            return _internal._Minimum(left, right)
        if isinstance(right, Number):
            return _internal._MinimumScalar(left, scalar=right)
    elif isinstance(left, Number):
        if isinstance(right, Symbol):
            # The scalar variant is used with the operands swapped.
            return _internal._MinimumScalar(right, scalar=left)
        if isinstance(right, Number):
            # Plain Python numbers: no graph node is created.
            if left < right:
                return left
            return right
    raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right))))
| |
| |
| # pylint: disable=no-member |
| # pylint: disable=redefined-builtin |
def hypot(left, right):
    """Compute the hypotenuse of a right triangle from its two "legs".

    Equivalent to :math:`\\sqrt(left^2 + right^2)`, element-wise.
    Each operand may be a Symbol or a scalar number. Broadcasting is not supported.

    Parameters
    ----------
    left : Symbol or scalar
        First leg of the triangle(s).
    right : Symbol or scalar
        Second leg of the triangle(s).

    Returns
    -------
    Symbol or scalar
        A Symbol when at least one operand is a Symbol; for two numbers, the
        value returned by ``numpy.hypot``.

    Raises
    ------
    TypeError
        If the operands are not a supported combination of Symbol and number.

    Examples
    --------
    >>> mx.sym.hypot(3, 4)
    5.0
    >>> x = mx.sym.Variable('x')
    >>> y = mx.sym.Variable('y')
    >>> z = mx.sym.hypot(x, 4)
    >>> z.eval(x=mx.nd.array([3,5,2]))[0].asnumpy()
    array([ 5., 6.40312433, 4.47213602], dtype=float32)
    >>> z = mx.sym.hypot(x, y)
    >>> z.eval(x=mx.nd.array([3,4]), y=mx.nd.array([10,2]))[0].asnumpy()
    array([ 10.44030666, 4.47213602], dtype=float32)
    """
    if isinstance(left, Symbol):
        if isinstance(right, Symbol):
            return _internal._Hypot(left, right)
        if isinstance(right, Number):
            return _internal._HypotScalar(left, scalar=right)
    elif isinstance(left, Number):
        if isinstance(right, Symbol):
            # The scalar variant is used with the operands swapped.
            return _internal._HypotScalar(right, scalar=left)
        if isinstance(right, Number):
            # Two plain numbers are evaluated immediately with NumPy.
            return _numpy.hypot(left, right)
    raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right))))
| |
| |
def zeros(shape, dtype=None, **kwargs):
    """Create a symbol of the given shape and type, filled with zeros.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new array.
    dtype : str or numpy.dtype, optional
        The value type of the inner value, default to ``np.float32``.
    **kwargs
        Extra keyword arguments, forwarded unchanged to ``_internal._zeros``.

    Returns
    -------
    out : Symbol
        The created Symbol.
    """
    # An omitted dtype falls back to 32-bit floats.
    resolved_dtype = _numpy.float32 if dtype is None else dtype
    return _internal._zeros(shape=shape, dtype=resolved_dtype, **kwargs)
| |
| |
def ones(shape, dtype=None, **kwargs):
    """Create a symbol of the given shape and type, filled with ones.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new array.
    dtype : str or numpy.dtype, optional
        The value type of the inner value, default to ``np.float32``.
    **kwargs
        Extra keyword arguments, forwarded unchanged to ``_internal._ones``.

    Returns
    -------
    out : Symbol
        The created Symbol.
    """
    # An omitted dtype falls back to 32-bit floats.
    resolved_dtype = _numpy.float32 if dtype is None else dtype
    return _internal._ones(shape=shape, dtype=resolved_dtype, **kwargs)
| |
| |
def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
    """Create a symbol holding evenly spaced values within a given interval.

    Parameters
    ----------
    start : number
        Start of interval. The interval includes this value.
    stop : number, optional
        End of interval. The interval does not include this value.
    step : number, optional
        Spacing between values.
    repeat : int, optional
        The repeating time of all elements.
        E.g repeat=3, the element a will be repeated three times --> a, a, a.
    name : str, optional
        Forwarded as the ``name`` keyword of ``_internal._arange``.
    dtype : str or numpy.dtype, optional
        The value type of the inner value, default to ``np.float32``.

    Returns
    -------
    out : Symbol
        The created Symbol.
    """
    # An omitted dtype falls back to 32-bit floats.
    resolved_dtype = _numpy.float32 if dtype is None else dtype
    return _internal._arange(
        start=start,
        stop=stop,
        step=step,
        repeat=repeat,
        name=name,
        dtype=resolved_dtype)
| |
| |
def _make_atomic_symbol_function(handle, name):
    """Create an atomic symbol function from an operator handle and a function name.

    The operator's metadata is fetched with ``MXSymbolGetAtomicSymbolInfo``.
    From it, the Python source of a wrapper function is assembled as text,
    compiled with ``exec`` and returned.

    Parameters
    ----------
    handle : ctypes handle
        Operator handle. It is passed to the C API, and ``handle.value`` is
        formatted as an integer into the generated source.
    name : str
        Name of the generated function. Its lower-cased form is used as the
        hint passed to ``NameManager.current.get``.

    Returns
    -------
    function
        The generated wrapper, with ``__name__``, ``__doc__`` and
        ``__module__`` filled in.
    """
    # Output slots filled in by the C API call below.
    c_real_name = ctypes.c_char_p()
    c_desc = ctypes.c_char_p()
    c_num_args = mx_uint()
    c_arg_names = ctypes.POINTER(ctypes.c_char_p)()
    c_arg_types = ctypes.POINTER(ctypes.c_char_p)()
    c_arg_descs = ctypes.POINTER(ctypes.c_char_p)()
    c_key_var_num_args = ctypes.c_char_p()
    c_ret_type = ctypes.c_char_p()

    check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
        handle, ctypes.byref(c_real_name), ctypes.byref(c_desc),
        ctypes.byref(c_num_args),
        ctypes.byref(c_arg_names),
        ctypes.byref(c_arg_types),
        ctypes.byref(c_arg_descs),
        ctypes.byref(c_key_var_num_args),
        ctypes.byref(c_ret_type)))

    # Convert the C strings into Python strings.
    func_name = name
    narg = int(c_num_args.value)
    arg_names = [py_str(c_arg_names[i]) for i in range(narg)]
    arg_types = [py_str(c_arg_types[i]) for i in range(narg)]
    key_var_num_args = py_str(c_key_var_num_args.value)
    ret_type = '' if c_ret_type.value is None else py_str(c_ret_type.value)
    desc_text = py_str(c_desc.value)
    arg_descs = [py_str(c_arg_descs[i]) for i in range(narg)]
    doc_str = _build_doc(func_name, desc_text, arg_names, arg_types,
                         arg_descs, key_var_num_args, ret_type)

    # Sort the operator arguments into symbol inputs, the dtype argument and
    # the remaining keyword parameters.
    dtype_name = None
    arr_name = None
    ndsignature = []
    signature = []
    ndarg_names = []
    kwarg_names = []
    for arg_name, arg_type in zip(arg_names, arg_types):
        if arg_name == 'dtype':
            dtype_name = arg_name
            signature.append('%s=_Null' % arg_name)
            continue
        if not arg_type.startswith(('NDArray', 'Symbol')):
            signature.append('%s=_Null' % arg_name)
            kwarg_names.append(arg_name)
            continue
        assert not arr_name, \
            "Op can only have one argument with variable " \
            "size and it must be the last argument."
        if arg_type.endswith('[]'):
            # Variable-length input: becomes ``*args`` of the wrapper.
            ndsignature.append('*%s' % arg_name)
            arr_name = arg_name
        else:
            ndsignature.append('%s=None' % arg_name)
            ndarg_names.append(arg_name)
    signature = ndsignature + signature + [
        'name=None', 'attr=None', 'out=None', '**kwargs']

    name_hint = func_name.lower()
    code = []
    if arr_name:
        # Variadic operator: positional arguments are collected into
        # ``sym_args`` and every option arrives through ``**kwargs``.
        code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
        code.append("""
    sym_args = []
    for i in {}:
        assert isinstance(i, SymbolBase), \\
            "Positional arguments must be Symbol instances, " \\
            "but got %s"%str(i)
        sym_args.append(i)""".format(arr_name))
        if dtype_name is not None:
            code.append("""
    if '%s' in kwargs:
        kwargs['%s'] = _numpy.dtype(kwargs['%s']).name"""%(
            dtype_name, dtype_name, dtype_name))
        code.append("""
    attr = kwargs.pop('attr', None)
    kwargs.update(AttrScope.current.get(attr))
    name = kwargs.pop('name', None)
    name = NameManager.current.get(name, '%s')
    _ = kwargs.pop('out', None)
    keys = []
    vals = []
    sym_kwargs = dict()
    for k, v in kwargs.items():
        if isinstance(v, SymbolBase):
            sym_kwargs[k] = v
        else:
            keys.append(k)
            vals.append(v)"""%(name_hint))
        if key_var_num_args:
            # Supply the input count when the caller did not pass it.
            code.append("""
    if '%s' not in kwargs:
        keys.append('%s')
        vals.append(len(sym_args) + len(sym_kwargs))"""%(
            key_var_num_args, key_var_num_args))

        code.append("""
    return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
        handle.value))
    else:
        # Fixed-arity operator: the wrapper gets an explicit signature.
        code.append("""
def %s(%s):
    kwargs.update(AttrScope.current.get(attr))
    sym_kwargs = dict()
    keys = []
    vals = []"""%(func_name, ', '.join(signature)))
        code.append("""
    for k, v in kwargs.items():
        if isinstance(v, SymbolBase):
            sym_kwargs[k] = v
        else:
            keys.append(k)
            vals.append(v)""")
        # Symbol inputs
        for input_name in ndarg_names:
            code.append("""
    if {name} is not None:
        assert isinstance({name}, SymbolBase), \\
            "Argument {name} must be Symbol instances, but got %s"%str({name})
        sym_kwargs['{name}'] = {name}""".format(name=input_name))
        # Remaining keyword parameters
        for param_name in kwarg_names:
            code.append("""
    if %s is not _Null:
        keys.append('%s')
        vals.append(%s)"""%(param_name, param_name, param_name))
        # dtype is passed on by its NumPy dtype name
        if dtype_name is not None:
            code.append("""
    if %s is not _Null:
        keys.append('%s')
        vals.append(_numpy.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

        code.append("""
    name = NameManager.current.get(name, '%s')
    return _symbol_creator(%d, None, sym_kwargs, keys, vals, name)"""%(
        name_hint, handle.value))

    # Compile the assembled source; the new function lands in ``namespace``.
    namespace = {}
    exec(''.join(code), None, namespace)  # pylint: disable=exec-used
    symbol_function = namespace[func_name]
    symbol_function.__name__ = func_name
    symbol_function.__doc__ = doc_str
    symbol_function.__module__ = 'mxnet.symbol'
    return symbol_function
| |
| |
def _init_symbol_module(symbol_class, root_namespace):
    """List all atomic symbol functions and attach them to the symbol modules.

    Operator names are obtained from ``MXListAllOpNames``. A wrapper is built
    for each one and attached to one of three already-imported modules:

    * names starting with ``_contrib_`` go to ``<root_namespace>.contrib.symbol``
      with that prefix stripped;
    * other names starting with ``_`` go to ``<root_namespace>._symbol_internal``;
    * everything else goes to ``<root_namespace>.symbol``.

    Parameters
    ----------
    symbol_class : class
        Class handed to ``_set_symbol_class``.
    root_namespace : str
        Package name prefix used to look the target modules up in ``sys.modules``.
    """
    _set_symbol_class(symbol_class)

    names_ptr = ctypes.POINTER(ctypes.c_char_p)()
    count = ctypes.c_uint()
    check_call(_LIB.MXListAllOpNames(ctypes.byref(count),
                                     ctypes.byref(names_ptr)))
    op_names = [py_str(names_ptr[idx]) for idx in range(count.value)]

    public_module = _sys.modules["%s.symbol" % root_namespace]
    internal_module = _sys.modules["%s._symbol_internal" % root_namespace]
    contrib_module = _sys.modules["%s.contrib.symbol" % root_namespace]

    contrib_prefix = '_contrib_'
    for op_name in op_names:
        op_handle = OpHandle()
        check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle)))
        function = _make_atomic_symbol_function(op_handle, op_name)
        generated_name = function.__name__
        if generated_name.startswith(contrib_prefix):
            # Contrib operators are exposed without their prefix.
            function.__name__ = generated_name[len(contrib_prefix):]
            function.__module__ = 'mxnet.contrib.symbol'
            target_module = contrib_module
        elif generated_name.startswith('_'):
            target_module = internal_module
        else:
            target_module = public_module
        setattr(target_module, function.__name__, function)
| |
| |
# Register the atomic symbol functions when this module is imported.
_init_symbol_module(symbol_class=Symbol, root_namespace="mxnet")