blob: 0750212a8bb6ce62971543032e5fa70c17211162 [file] [log] [blame]
from __future__ import absolute_import as _abs
import sys as _sys
import ctypes as _ctypes
import numpy as _numpy
from numbers import Number as _Number
from ..name import NameManager
from ..attribute import AttrScope
from ..symbol_doc import _build_doc
include "./base.pyi"
cdef class SymbolBase:
    """Symbol is symbolic graph."""
    # handle for symbolic operator.
    cdef SymbolHandle chandle

    cdef _set_handle(self, handle):
        """Store the raw pointer carried by ``handle`` in ``chandle``.

        ``handle`` is None or an object exposing ``.value`` (e.g. a ctypes
        pointer). None, or a pointer whose ``.value`` is None, stores NULL.
        NOTE(review): any previously held handle is not released here.
        """
        cdef unsigned long long ptr
        # A NULL ctypes pointer reports ``.value`` as None; treat it like None
        # instead of failing on the integer conversion below.
        raw = None if handle is None else handle.value
        if raw is None:
            self.chandle = NULL
        else:
            ptr = raw
            self.chandle = <SymbolHandle>(ptr)

    # ``handle`` exposes ``chandle`` as a ctypes.c_void_p (None when NULL);
    # assigning to it goes through _set_handle.
    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle):
        self._set_handle(handle)

    def __dealloc__(self):
        # Only release a handle this object actually holds; wrappers created
        # with ``None`` never owned one.
        if self.chandle != NULL:
            CALL(NNSymbolFree(self.chandle))

    def _set_attr(self, **kwargs):
        """Set the attribute of the symbol.
        Parameters
        ----------
        **kwargs
            The attributes to set; values are converted with ``str``.
        """
        SymbolSetAttr(self.chandle, kwargs)

    def __reduce__(self):
        # Pickle support: rebuild via the current symbol class with a NULL
        # handle, then restore from __getstate__ (defined outside this view).
        return (_symbol_cls, (None,), self.__getstate__())
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
    """Attach each ``key=value`` pair of ``kwargs`` to ``handle``.

    Keys and ``str(value)`` are encoded with ``c_str`` and passed to
    ``MXSymbolSetAttr`` one pair at a time; a failing call raises via CALL.
    """
    cdef string key_buf
    cdef string val_buf
    for attr_key, attr_val in kwargs.items():
        # The std::string locals own the encoded bytes, so the char* views
        # handed to the C API remain valid for the duration of the call.
        key_buf = c_str(attr_key)
        val_buf = c_str(str(attr_val))
        CALL(MXSymbolSetAttr(handle, key_buf.c_str(), val_buf.c_str()))
# Class that NewSymbol and SymbolBase.__reduce__ instantiate; defaults to
# SymbolBase and is swapped by _set_symbol_class (see _init_symbol_module).
_symbol_cls = SymbolBase

cdef _set_symbol_class(cls):
    """Make ``cls`` the class used to wrap newly created symbol handles."""
    global _symbol_cls
    _symbol_cls = cls
cdef NewSymbol(SymbolHandle handle):
    """Wrap ``handle`` in a fresh instance of the current symbol class.

    The wrapper is built with a NULL handle and then adopts ``handle``, which
    SymbolBase.__dealloc__ frees when the wrapper is deallocated.
    """
    wrapper = _symbol_cls(None)
    # Unchecked cast, as before: assumes _symbol_cls derives from SymbolBase.
    cdef SymbolBase base_view = <SymbolBase>wrapper
    base_view.chandle = handle
    return wrapper
cdef _make_atomic_symbol_function(OpHandle handle, string name):
    """Create an atomic symbol function by handle and function name.

    Parameters
    ----------
    handle : OpHandle
        Handle of the operator, passed to the C API calls.
    name : string
        Operator name; becomes the returned function's ``__name__``.

    Returns
    -------
    creator : function
        Builds the operator from its non-Symbol keyword arguments (sent as
        string parameters) and composes it with Symbol inputs given either
        positionally or by keyword, but not both.
    """
    cdef const char *real_name
    cdef const char *desc
    cdef nn_uint num_args
    cdef const char** arg_names
    cdef const char** arg_types
    cdef const char** arg_descs
    cdef const char* return_type
    cdef const char* key_var_num_args
    CALL(MXSymbolGetAtomicSymbolInfo(
        handle, &real_name, &desc,
        &num_args, &arg_names,
        &arg_types, &arg_descs,
        &key_var_num_args, &return_type))
    func_name = py_str(name.c_str())
    # Guard against NULL the same way return_type is guarded below.
    key_vargs = py_str(key_var_num_args) if key_var_num_args != NULL else ''
    doc_str = _build_doc(func_name,
                         py_str(desc),
                         [py_str(arg_names[i]) for i in range(num_args)],
                         [py_str(arg_types[i]) for i in range(num_args)],
                         [py_str(arg_descs[i]) for i in range(num_args)],
                         key_vargs,
                         py_str(return_type) if return_type != NULL else '')
    func_hint = func_name.lower()

    def creator(*args, **kwargs):
        cdef vector[string] sparam_keys
        cdef vector[string] sparam_vals
        cdef vector[SymbolHandle] symbol_args
        cdef vector[string] ssymbol_keys
        cdef SymbolHandle ret_handle

        attr = kwargs.pop("attr", None)
        kwargs.update(AttrScope.current.get(attr))
        sym_name = kwargs.pop("name", None)

        # If the operator names a key for its input count, pass len(args)
        # under that key unless the caller already supplied it.
        if key_vargs and key_vargs not in kwargs:
            sparam_keys.push_back(c_str(key_vargs))
            sparam_vals.push_back(c_str(str(len(args))))

        # Symbol-valued kwargs are compose inputs; everything else is a param.
        for k, v in kwargs.items():
            if isinstance(v, SymbolBase):
                ssymbol_keys.push_back(c_str(k))
                symbol_args.push_back((<SymbolBase>v).chandle)
            elif k == 'dtype':
                sparam_keys.push_back(c_str(k))
                sparam_vals.push_back(c_str(_numpy.dtype(v).name))
            else:
                sparam_keys.push_back(c_str(k))
                sparam_vals.push_back(c_str(str(v)))

        if len(args) != 0:
            if symbol_args.size() != 0:
                raise TypeError("compose only accept input Symbols "
                                "either as positional or keyword arguments, not both")
            for v in args:
                if not isinstance(v, SymbolBase):
                    raise TypeError('Compose expect `Symbol` as arguments')
                symbol_args.push_back((<SymbolBase>v).chandle)

        # char* views into the std::string vectors above, which outlive them.
        cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
        cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
        cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys)
        CALL(MXSymbolCreateAtomicSymbol(
            handle,
            <nn_uint>param_keys.size(),
            CBeginPtr(param_keys),
            CBeginPtr(param_vals),
            &ret_handle))
        # Take ownership immediately: if naming or composition below raises,
        # SymbolBase.__dealloc__ frees the handle instead of leaking it.
        sym = NewSymbol(ret_handle)

        sym_name = NameManager.current.get(sym_name, func_hint)
        cdef const char* c_name = NULL
        if sym_name:
            # Keep the encoded bytes referenced so c_name stays valid.
            name_bytes = c_str(sym_name)
            c_name = name_bytes
        CALL(NNSymbolCompose(
            ret_handle,
            c_name,
            <nn_uint>symbol_args.size(),
            &symbol_keys[0] if symbol_keys.size() != 0 else NULL,
            &symbol_args[0] if symbol_args.size() != 0 else NULL))
        return sym

    creator.__name__ = func_name
    creator.__doc__ = doc_str
    creator.__module__ = 'mxnet.symbol'
    return creator
def _init_symbol_module(symbol_class, root_namespace):
    """List and add all the atomic symbol functions to current module.

    Functions whose name starts with an underscore go to
    ``<root_namespace>._symbol_internal``; the rest go to
    ``<root_namespace>.symbol``. ``symbol_class`` becomes the wrapper class
    for every symbol these functions create.
    """
    cdef const char** names_ptr
    cdef nn_uint n_ops
    cdef vector[string] names
    cdef OpHandle op_handle
    cdef size_t idx
    _set_symbol_class(symbol_class)
    CALL(MXListAllOpNames(&n_ops, &names_ptr))
    # Copy the names into owned std::strings before making further C API
    # calls (presumably the returned array is not guaranteed to survive them).
    for idx in range(n_ops):
        names.push_back(string(names_ptr[idx]))
    public_mod = _sys.modules["%s.symbol" % root_namespace]
    internal_mod = _sys.modules["%s._symbol_internal" % root_namespace]
    for idx in range(names.size()):
        CALL(NNGetOpHandle(names[idx].c_str(), &op_handle))
        fn = _make_atomic_symbol_function(op_handle, names[idx])
        target_mod = internal_mod if fn.__name__.startswith('_') else public_mod
        setattr(target_mod, fn.__name__, fn)