blob: e8787fba77a3314c05a0d84523f29f041a198ed1 [file] [log] [blame]
from __future__ import absolute_import as _abs
import sys as _sys
import ctypes as _ctypes
import numpy as _numpy
from numbers import Number as _Number
from ..name import NameManager
from ..attribute import AttrScope
from ..symbol_doc import _build_doc
include "./base.pyi"
cdef class SymbolBase:
"""Symbol is symbolic graph."""
# handle for symbolic operator.
cdef SymbolHandle chandle
cdef _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = handle.value
self.chandle = <SymbolHandle>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self):
CALL(NNSymbolFree(self.chandle))
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
"""
SymbolSetAttr(self.chandle, kwargs)
def __reduce__(self):
return (_symbol_cls, (None,), self.__getstate__())
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef string sparam_key
cdef string sparam_val
cdef const char* param_key
cdef const char* param_val
for k, v in kwargs.items():
sparam_key = c_str(k)
sparam_val = c_str(str(v))
param_key = sparam_key.c_str()
param_val = sparam_val.c_str()
CALL(MXSymbolSetAttr(handle, param_key, param_val))
_symbol_cls = SymbolBase
def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
cdef NewSymbol(SymbolHandle handle):
"""Create a new symbol given handle"""
sym = _symbol_cls(None)
(<SymbolBase>sym).chandle = handle
return sym
def invoke(cached_op, args, name=None):
cdef SymbolHandle ret
cdef vector[SymbolHandle] sym_args
hint = cached_op.op.lower()
cdef string cname = c_str(NameManager.current.get(name, hint))
for i in args:
sym_args.push_back((<SymbolBase>i).chandle)
CALL(MXCachedCreateSymbol(
(<CachedOp>cached_op).chandle,
cname.c_str(),
<int>len(args),
&sym_args[0] if sym_args.size() != 0 else NULL,
&ret))
return NewSymbol(ret)
def _symbol_creator(handle, args, kwargs, keys, vals, name):
cdef unsigned long long ihandle = handle
cdef OpHandle chandle = <OpHandle>ihandle
cdef vector[string] ckeys
cdef vector[string] cvals
cdef vector[string] sym_keys
cdef vector[SymbolHandle] sym_args
cdef SymbolHandle ret_handle
cdef string cname = c_str(name)
for i in keys:
ckeys.push_back(c_str(i))
for i in vals:
cvals.push_back(c_str(str(i)))
cdef vector[const char*] param_keys = SVec2Ptr(ckeys)
cdef vector[const char*] param_vals = SVec2Ptr(cvals)
CALL(MXSymbolCreateAtomicSymbol(
chandle,
<nn_uint>param_keys.size(),
CBeginPtr(param_keys),
CBeginPtr(param_vals),
&ret_handle))
if args and kwargs:
raise TypeError(
'Operators with variable length input can only accept input'
'Symbols either as positional or keyword arguments, not both')
if args:
for i in args:
sym_args.push_back((<SymbolBase>i).chandle)
elif kwargs:
for k, v in kwargs.items():
sym_keys.push_back(c_str(k))
sym_args.push_back((<SymbolBase>v).chandle)
cdef vector[const char*] csym_keys = SVec2Ptr(sym_keys)
CALL(NNSymbolCompose(
ret_handle,
cname.c_str(),
<nn_uint>sym_args.size(),
&csym_keys[0] if csym_keys.size() != 0 else NULL,
&sym_args[0] if sym_args.size() != 0 else NULL))
return NewSymbol(ret_handle)