| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file c_api.h |
| * \brief C API of mxnet |
| */ |
| #ifndef MXNET_C_API_H_ |
| #define MXNET_C_API_H_ |
| |
| /*! \brief Inhibit C++ name-mangling for MXNet functions. */ |
| #ifdef __cplusplus |
| extern "C" { |
| #endif // __cplusplus |
| |
| /*! |
| * \brief Default-argument helper for declarations shared by C and C++. |
| * DEFAULT(x) expands to "= x" when compiled as C++ and to nothing when |
| * compiled as C, so a parameter keeps its default value in C++ only. |
| */ |
| #ifdef __cplusplus |
| #define DEFAULT(x) = x |
| #else |
| #define DEFAULT(x) |
| #endif // __cplusplus |
| |
| #include <stdint.h> |
| |
| #include <stdint.h> |
| #include <stddef.h> |
| #include <stdbool.h> |
| |
| /*! |
| * \brief Linkage prefix used on the API declarations in this header. |
| * On Windows (_WIN32) it is __declspec(dllexport) when MXNET_EXPORTS is |
| * defined and __declspec(dllimport) otherwise; elsewhere it is empty. |
| */ |
| #ifdef _WIN32 |
| #ifdef MXNET_EXPORTS |
| #define MXNET_DLL __declspec(dllexport) |
| #else |
| #define MXNET_DLL __declspec(dllimport) |
| #endif /* MXNET_EXPORTS */ |
| #else |
| #define MXNET_DLL |
| #endif /* _WIN32 */ |
| |
| /*! \brief unsigned integer type of the API (alias of uint32_t) */ |
| typedef uint32_t mx_uint; |
| /*! \brief floating point type of the API (alias of float) */ |
| typedef float mx_float; |
| /*! \brief data type used to store the size of a dimension (alias of int64_t) */ |
| typedef int64_t dim_t; |
| // All the handles are simply void * (or const void *). |
| // They are cast internally to specific pointer types; |
| // these typedefs are mainly used for readability. |
| /*! \brief handle to an NDArray */ |
| typedef void *NDArrayHandle; |
| /*! \brief handle to an MXNet NDArray function that changes an NDArray */ |
| typedef const void *FunctionHandle; |
| /*! \brief handle to a function that takes parameters and creates a symbol */ |
| typedef void *AtomicSymbolCreator; |
| /*! \brief handle to a cached operator */ |
| typedef void *CachedOpHandle; |
| /*! \brief handle to a symbol that can be bound as an operator */ |
| typedef void *SymbolHandle; |
| /*! \brief handle to an AtomicSymbol */ |
| typedef void *AtomicSymbolHandle; |
| /*! \brief handle to an Executor */ |
| typedef void *ExecutorHandle; |
| /*! \brief handle to a DataIter creator */ |
| typedef void *DataIterCreator; |
| /*! \brief handle to a DataIterator */ |
| typedef void *DataIterHandle; |
| /*! \brief handle to a KVStore */ |
| typedef void *KVStoreHandle; |
| /*! \brief handle to a RecordIO */ |
| typedef void *RecordIOHandle; |
| /*! \brief handle to an MXRtc */ |
| typedef void *RtcHandle; |
| /*! \brief handle to an rtc cuda module */ |
| typedef void *CudaModuleHandle; |
| /*! \brief handle to an rtc cuda kernel */ |
| typedef void *CudaKernelHandle; |
| /*! \brief handle to a Profile object (domain, duration, counter, etc.) */ |
| typedef void *ProfileHandle; |
| /*! \brief handle to a DLManagedTensor */ |
| typedef void *DLManagedTensorHandle; |
| /*! \brief handle to a Context (read-only: const void *) */ |
| typedef const void *ContextHandle; |
| /*! \brief handle to an Engine FnProperty (read-only: const void *) */ |
| typedef const void *EngineFnPropertyHandle; |
| /*! \brief handle to an Engine VarHandle */ |
| typedef void *EngineVarHandle; |
| |
| /*! \brief Engine asynchronous operation; takes three opaque pointers */ |
| typedef void (*EngineAsyncFunc)(void*, void*, void*); |
| /*! \brief Engine synchronous operation; takes two opaque pointers */ |
| typedef void (*EngineSyncFunc)(void*, void*); |
| /*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */ |
| typedef void (*EngineFuncParamDeleter)(void*); |
| typedef void (*ExecutorMonitorCallback)(const char*, |
| NDArrayHandle, |
| void*); /* Executor monitor callback: receives a string, an NDArray handle and an opaque pointer. NOTE(review): the meaning of each argument is not documented here - confirm. */ |
| /*! |
| * \brief Monitor callback called at operator level for cached op. |
| * Receives two strings and an NDArray handle. |
| */ |
| typedef void (*CachedOpMonitorCallback)(const char*, |
| const char*, |
| NDArrayHandle); |
| |
| |
| struct NativeOpInfo { /* Callback table of a native operator: five function pointers followed by five payload pointers. */ |
| void (*forward)(int, float**, int*, unsigned**, int*, void*); |
| void (*backward)(int, float**, int*, unsigned**, int*, void*); |
| void (*infer_shape)(int, int*, unsigned**, void*); |
| void (*list_outputs)(char***, void*); |
| void (*list_arguments)(char***, void*); |
| // all functions also pass a payload void* pointer (their last parameter); presumably the p_* field with the matching name holds that payload - confirm |
| void* p_forward; |
| void* p_backward; |
| void* p_infer_shape; |
| void* p_list_outputs; |
| void* p_list_arguments; |
| }; |
| |
| struct NDArrayOpInfo { /* Callback table of an NDArray operator: six function pointers returning bool, followed by six payload pointers. */ |
| bool (*forward)(int, void**, int*, void*); |
| bool (*backward)(int, void**, int*, void*); |
| bool (*infer_shape)(int, int*, unsigned**, void*); |
| bool (*list_outputs)(char***, void*); |
| bool (*list_arguments)(char***, void*); |
| bool (*declare_backward_dependency)(const int*, const int*, const int*, |
| int*, int**, void*); |
| // all functions also pass a payload void* pointer (their last parameter); presumably the p_* field with the matching name holds that payload - confirm |
| void* p_forward; |
| void* p_backward; |
| void* p_infer_shape; |
| void* p_list_outputs; |
| void* p_list_arguments; |
| void* p_declare_backward_dependency; |
| }; |
| |
| typedef int (*MXGenericCallback)(void); /* generic callback type: no parameters, returns int */ |
| |
| struct MXCallbackList { /* A list of callbacks with one context pointer array alongside. */ |
| int num_callbacks; /* number of callbacks; presumably the length of both arrays below - confirm */ |
| int (**callbacks)(void); /* array of function pointers with the MXGenericCallback signature */ |
| void **contexts; /* array of opaque context pointers */ |
| }; |
| |
| struct LibFeature { /* One runtime feature entry: a name and an enabled flag. */ |
| const char* name; /* feature name */ |
| bool enabled; /* whether the feature is enabled */ |
| }; |
| |
| enum CustomOpCallbacks { /* Identifiers of custom operator callbacks; values are the implicit 0, 1, 2. Presumably indices into MXCallbackList::callbacks - confirm. */ |
| kCustomOpDelete, |
| kCustomOpForward, |
| kCustomOpBackward |
| }; |
| |
| enum CustomOpPropCallbacks { /* Identifiers of custom operator property callbacks; values are the implicit 0..9, so the order is significant. Presumably indices into MXCallbackList::callbacks - confirm. */ |
| kCustomOpPropDelete, |
| kCustomOpPropListArguments, |
| kCustomOpPropListOutputs, |
| kCustomOpPropListAuxiliaryStates, |
| kCustomOpPropInferShape, |
| kCustomOpPropDeclareBackwardDependency, |
| kCustomOpPropCreateOperator, |
| kCustomOpPropInferType, |
| kCustomOpPropInferStorageType, |
| kCustomOpPropBackwardInferStorageType |
| }; |
| |
| |
| typedef int (*CustomOpFBFunc)(int /*size*/, void** /*ptrs*/, int* /*tags*/, |
| const int* /*reqs*/, const int /*is_train*/, |
| void* /*state*/); /* presumably the type of the kCustomOpForward / kCustomOpBackward callbacks - confirm */ |
| typedef int (*CustomOpDelFunc)(void* /*state*/); /* presumably the type of the delete callbacks - confirm */ |
| typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/); /* presumably the type of the list-arguments / list-outputs / list-auxiliary-states callbacks - confirm */ |
| typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/, |
| int** /*shapes*/, void* /*state*/); /* presumably the type of the kCustomOpPropInferShape callback - confirm */ |
| typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/); /* presumably the type of the kCustomOpPropInferStorageType callback - confirm */ |
| typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/, |
| int * /*stypes*/, |
| int * /*tags*/, |
| void * /*state*/); /* presumably the type of the kCustomOpPropBackwardInferStorageType callback - confirm */ |
| typedef int (*CustomOpInferTypeFunc)(int /*num_input*/, int* /*types*/, void* /*state*/); /* presumably the type of the kCustomOpPropInferType callback - confirm */ |
| typedef int (*CustomOpBwdDepFunc)(const int* /*out_grad*/, const int* /*in_data*/, |
| const int* /*out_data*/, int* /*num_deps*/, |
| int** /*rdeps*/, void* /*state*/); /* presumably the type of the kCustomOpPropDeclareBackwardDependency callback - confirm */ |
| typedef int (*CustomOpCreateFunc)(const char* /*ctx*/, int /*num_inputs*/, |
| unsigned** /*shapes*/, const int* /*ndims*/, |
| const int* /*dtypes*/, struct MXCallbackList* /*ret*/, |
| void* /*state*/); /* presumably the type of the kCustomOpPropCreateOperator callback; ret is a callback list pointer - confirm who fills it */ |
| typedef int (*CustomOpPropCreator)(const char* /*op_type*/, const int /*num_kwargs*/, |
| const char** /*keys*/, const char** /*values*/, |
| struct MXCallbackList* /*ret*/); /* creator taking an operator type and key/value string arrays; ret is a callback list pointer - confirm who fills it */ |
| |
| |
| enum CustomFunctionCallbacks { /* Identifiers of custom function callbacks; values are the implicit 0 and 1. Presumably indices into MXCallbackList::callbacks - confirm. */ |
| kCustomFunctionBackward, |
| kCustomFunctionDelete |
| }; |
| |
| typedef int (*CustomFunctionBwdFunc)(int /*num_ograds*/, int /*num_igrads*/, void** /*ptrs*/, |
| const int* /*reqs*/, const int /*is_train*/, |
| void* /*state*/); /* presumably the type of the kCustomFunctionBackward callback - confirm */ |
| typedef int (*CustomFunctionDelFunc)(void* /*state*/); /* presumably the type of the kCustomFunctionDelete callback - confirm */ |
| |
| /*! |
| * \brief Return the string message of the last error. |
| * All functions in this file return 0 on success |
| * and -1 when an error occurred; |
| * MXGetLastError can be called to retrieve the error. |
| * |
| * This function is thread-safe and can be called from different threads. |
| * \return error info |
| */ |
| MXNET_DLL const char *MXGetLastError(void); |
| |
| //------------------------------------- |
| // Part 0: Global State setups |
| //------------------------------------- |
| |
| /*! |
| * \brief Load library dynamically |
| * \param path path to the library .so file |
| * \param verbose 0 for quiet, 1 for verbose |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXLoadLib(const char *path, unsigned verbose); |
| |
| /*! |
| * \brief Get list of features supported on the runtime |
| * \param libFeature pointer that receives the array of LibFeature |
| * \param size pointer that receives the size of the array |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXLibInfoFeatures(const struct LibFeature **libFeature, size_t *size); |
| |
| /*! |
| * \brief Seed all global random number generators in mxnet. |
| * \param seed the random number seed. |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXRandomSeed(int seed); |
| |
| /*! |
| * \brief Seed the global random number generator of the given device. |
| * \param seed the random number seed. |
| * \param dev_type device type of the device whose generator is seeded. |
| * \param dev_id device id of the device whose generator is seeded. |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id); |
| |
| /*! |
| * \brief Notify the engine about a shutdown. |
| * This can help the engine print fewer messages to the display. |
| * |
| * Users do not have to call this function. |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXNotifyShutdown(void); |
| |
| /*! |
| * \brief Set up configuration of profiler for the process passed as profile_process in keys |
| * \param num_params Number of parameters |
| * \param keys array of parameter keys |
| * \param vals array of parameter values |
| * \param kvstoreHandle handle to kvstore |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys, |
| const char* const* vals, |
| KVStoreHandle kvstoreHandle); |
| |
| /*! |
| * \brief Set up configuration of profiler for worker/current process |
| * \param num_params Number of parameters |
| * \param keys array of parameter keys |
| * \param vals array of parameter values |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals); |
| |
| /*! |
| * \brief Set up state of profiler for either worker or server process |
| * \param state indicate the working state of profiler, |
| * profiler not running when state == 0, |
| * profiler running when state == 1 |
| * \param profile_process an int, |
| * when 0 command is for worker/current process, |
| * when 1 command is for server process |
| * \param kvStoreHandle handle to kvstore, needed for server process profiling |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process, |
| KVStoreHandle kvStoreHandle); |
| |
| /*! |
| * \brief Set up state of profiler for current process |
| * \param state indicate the working state of profiler, |
| * profiler not running when state == 0, |
| * profiler running when state == 1 |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXSetProfilerState(int state); |
| |
| /*! |
| * \brief Save profile and stop profiler |
| * \param finished nonzero if stat output should stop after this point |
| * \param profile_process an int, |
| * when 0 command is for worker/current process, |
| * when 1 command is for server process |
| * \param kvStoreHandle handle to kvstore |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle); |
| |
| |
| /*! |
| * \brief Save profile and stop profiler for worker/current process |
| * \param finished nonzero if stat output should stop after this point |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXDumpProfile(int finished); |
| |
| |
| /*! |
| * \brief Deprecated, use MXAggregateProfileStatsPrintEx instead. |
| * \param out_str Will receive a pointer to the output string |
| * \param reset Clear the aggregate stats after printing |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset); |
| |
| /*! |
| * \brief Print sorted aggregate stats to a string. |
| * How aggregate stats are stored will not change. |
| * \param out_str will receive a pointer to the output string |
| * \param reset clear the aggregate stats after printing |
| * \param format whether to return in tabular or json format |
| * \param sort_by sort by total, avg, min, max, or count |
| * \param ascending whether to sort in ascending order |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format, |
| int sort_by, int ascending); |
| |
| /*! |
| * \brief Pause profiler tuning collection |
| * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues |
| * \param profile_process integer which denotes whether to process worker or server process |
| * \param kvStoreHandle handle to kvstore |
| * \return 0 when success, -1 when failure happens. |
| * \note pausing and resuming is global and not recursive |
| */ |
| MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle); |
| |
| /*! |
| * \brief Pause profiler tuning collection for worker/current process |
| * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues |
| * \return 0 when success, -1 when failure happens. |
| * \note pausing and resuming is global and not recursive |
| */ |
| MXNET_DLL int MXProfilePause(int paused); |
| |
| /*! |
| * \brief Create profiling domain |
| * \param domain String representing the domain name to create |
| * \param out Return domain object |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileCreateDomain(const char *domain, ProfileHandle *out); |
| |
| /*! |
| * \brief Create profile task |
| * \param domain Domain of the task |
| * \param task_name Name of the task |
| * \param out Output handle |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileCreateTask(ProfileHandle domain, |
| const char *task_name, |
| ProfileHandle *out); |
| |
| /*! |
| * \brief Create profile frame |
| * \param domain Domain of the frame |
| * \param frame_name Name of the frame |
| * \param out Output handle |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileCreateFrame(ProfileHandle domain, |
| const char *frame_name, |
| ProfileHandle *out); |
| |
| /*! |
| * \brief Create profile event |
| * \param event_name Name of the event |
| * \param out Output handle |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileCreateEvent(const char *event_name, ProfileHandle *out); |
| |
| /*! |
| * \brief Create profile counter |
| * \param domain Domain of the counter |
| * \param counter_name Name of the counter |
| * \param out Output handle |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileCreateCounter(ProfileHandle domain, |
| const char *counter_name, |
| ProfileHandle *out); |
| |
| /*! |
| * \brief Destroy a frame |
| * NOTE(review): the function name and the generic ProfileHandle type suggest |
| * this may apply to any profile object, not only frames - confirm. |
| * \param frame_handle Handle to frame to destroy |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileDestroyHandle(ProfileHandle frame_handle); |
| |
| /*! |
| * \brief Start timing the duration of a profile duration object such as an event, task or frame |
| * \param duration_handle handle to the duration object |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileDurationStart(ProfileHandle duration_handle); |
| |
| /*! |
| * \brief Stop timing the duration of a profile duration object such as an event, task or frame |
| * \param duration_handle handle to the duration object |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileDurationStop(ProfileHandle duration_handle); |
| |
| /*! |
| * \brief Set a counter, given its handle |
| * \param counter_handle Handle to counter to set |
| * \param value Value to set the counter to (64-bit unsigned integer) |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileSetCounter(ProfileHandle counter_handle, uint64_t value); |
| |
| /*! |
| * \brief Adjust a counter by the given amount, given its handle |
| * \param counter_handle Handle to counter to adjust |
| * \param value Value to adjust the counter by (64-bit signed integer) |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileAdjustCounter(ProfileHandle counter_handle, int64_t value); |
| |
| /*! |
| * \brief Mark a single instant in time |
| * \param domain Domain of the marker |
| * \param instant_marker_name Name of the marker |
| * \param scope Scope of marker ('global', 'process', 'thread', 'task', 'marker') |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXProfileSetMarker(ProfileHandle domain, |
| const char *instant_marker_name, |
| const char *scope); |
| |
| /*! |
| * \brief Set the number of OMP threads to use |
| * \param thread_num Number of OMP threads desired |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXSetNumOMPThreads(int thread_num); |
| |
| /*! |
| * \brief Set the bulk execution limit |
| * \param bulk_size new bulk_size |
| * \param prev_bulk_size pointer that receives the previous bulk_size |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size); |
| |
| /*! |
| * \brief Get the number of GPUs. |
| * \param out pointer to int that will hold the number of GPUs available. |
| * \return 0 when success, -1 when failure happens. |
| */ |
| MXNET_DLL int MXGetGPUCount(int* out); |
| |
| /*! |
| * \brief get the free and total available memory on a GPU |
| * Note: Deprecated, use MXGetGPUMemoryInformation64 instead. |
| * \param dev the GPU number to query |
| * \param free_mem pointer to the integer holding free GPU memory |
| * \param total_mem pointer to the integer holding total GPU memory |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem); |
| |
| /*! |
| * \brief get the free and total available memory on a GPU |
| * \param dev the GPU number to query |
| * \param free_mem pointer to the uint64_t holding free GPU memory |
| * \param total_mem pointer to the uint64_t holding total GPU memory |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem); |
| |
| /*! |
| * \brief get the MXNet library version as an integer |
| * \param out pointer to the integer holding the version number |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXGetVersion(int *out); |
| |
| #if MXNET_USE_TVM_OP |
| /*! |
| * \brief Load TVM operator from the binary library |
| * \param libpath TVM operators lib file |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXLoadTVMOp(const char *libpath); |
| |
| /*! \brief A single option entity holding one int value. */ |
| struct OtherOptionEntity { |
| int val; |
| }; |
| |
| /*! \brief An option space: an array of entities and its size. */ |
| struct OtherOptionSpace { |
| /* struct keyword is required: C has no implicit typedef for struct tags */ |
| struct OtherOptionEntity* entities; |
| int entities_size; |
| }; |
| |
| /*! |
| * \brief A config space made of two keyed maps, each stored as a size, |
| * a key array and a value array. |
| */ |
| struct ConfigSpace { |
| int entity_map_size; |
| char** entity_map_key; |
| struct OtherOptionEntity* entity_map_val; |
| int space_map_size; |
| char** space_map_key; |
| struct OtherOptionSpace* space_map_val; |
| }; |
| |
| /*! \brief A keyed collection of config spaces (size, key array, value array). */ |
| typedef struct ConfigSpaces { |
| int spaces_size; |
| char** spaces_key; |
| struct ConfigSpace* spaces_val; |
| } ConfigSpaces; |
| |
| /*! |
| * \brief Load TVM configuration from the given config spaces |
| * \param config the config spaces, passed by value |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXLoadTVMConfig(ConfigSpaces config); |
| #endif // MXNET_USE_TVM_OP |
| |
| |
| //------------------------------------- |
| // Part 1: NDArray creation and deletion |
| //------------------------------------- |
| /*! |
| * \brief create an NDArray handle that is not initialized; |
| * it can be passed in as a mutate variable |
| * to hold the result of an NDArray |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out); |
| /*! |
| * \brief create an NDArray with specified shape |
| * \param shape the pointer to the shape |
| * \param ndim the dimension of the shape |
| * \param dev_type device type, specify device we want to take |
| * \param dev_id the device id of the specific device |
| * \param delay_alloc whether to delay allocation until |
| * the NDArray is first mutated |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCreate(const uint32_t *shape, |
| uint32_t ndim, |
| int dev_type, |
| int dev_id, |
| int delay_alloc, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief create an NDArray with specified shape and data type |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param shape the pointer to the uint32_t shape |
| * \param ndim the dimension of the shape |
| * \param dev_type device type, specify device we want to take |
| * \param dev_id the device id of the specific device |
| * \param delay_alloc whether to delay allocation until |
| * the NDArray is first mutated |
| * \param dtype data type of created array |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCreateEx(const uint32_t *shape, |
| uint32_t ndim, |
| int dev_type, |
| int dev_id, |
| int delay_alloc, |
| int dtype, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief create an NDArray with specified shape and data type |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param shape the pointer to int64_t shape |
| * \param ndim the dimension of the shape |
| * \param dev_type device type, specify device we want to take |
| * \param dev_id the device id of the specific device |
| * \param delay_alloc whether to delay allocation until |
| * the NDArray is first mutated |
| * \param dtype data type of created array |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCreateEx64(const int64_t *shape, |
| int ndim, |
| int dev_type, |
| int dev_id, |
| int delay_alloc, |
| int dtype, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief create an empty sparse NDArray with specified shape and data type |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param storage_type the storage type of the ndarray |
| * \param shape the pointer to the uint32_t shape |
| * \param ndim the dimension of the shape |
| * \param dev_type device type, specify device we want to take |
| * \param dev_id the device id of the specific device |
| * \param delay_alloc whether to delay allocation until |
| * the NDArray is first mutated |
| * \param dtype data type of created array |
| * \param num_aux the number of aux data to support this ndarray |
| * \param aux_type data type of the aux data for the created array |
| * \param aux_ndims the dimension of the shapes of aux data |
| * \param aux_shape the shapes of aux data |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, |
| const uint32_t *shape, |
| uint32_t ndim, |
| int dev_type, |
| int dev_id, |
| int delay_alloc, |
| int dtype, |
| uint32_t num_aux, |
| int *aux_type, |
| uint32_t *aux_ndims, |
| const uint32_t *aux_shape, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief create an empty sparse NDArray with specified shape and data type |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param storage_type the storage type of the ndarray |
| * \param shape the pointer to the int64_t shape |
| * \param ndim the dimension of the shape |
| * \param dev_type device type, specify device we want to take |
| * \param dev_id the device id of the specific device |
| * \param delay_alloc whether to delay allocation until |
| * the NDArray is first mutated |
| * \param dtype data type of created array |
| * \param num_aux the number of aux data to support this ndarray |
| * \param aux_type data type of the aux data for the created array |
| * \param aux_ndims the dimension of the shapes of aux data |
| * \param aux_shape the int64_t shapes of aux data |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type, |
| const int64_t *shape, |
| int ndim, |
| int dev_type, |
| int dev_id, |
| int delay_alloc, |
| int dtype, |
| uint32_t num_aux, |
| int *aux_type, |
| int *aux_ndims, |
| const int64_t *aux_shape, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief create an NDArray handle that is loaded from raw bytes. |
| * \param buf the head of the raw bytes |
| * \param size size of the raw bytes |
| * \param out the returning handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, |
| size_t size, |
| NDArrayHandle *out); |
| /*! |
| * \brief save the NDArray into raw bytes. |
| * \param handle the NDArray handle |
| * \param out_size size of the raw bytes |
| * \param out_buf the head of returning memory bytes. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, |
| size_t *out_size, |
| const char **out_buf); |
| /*! |
| * \brief Save a list of NDArrays into the file. |
| * \param fname name of the file. |
| * \param num_args number of arguments to save. |
| * \param args the array of NDArrayHandles to be saved. |
| * \param keys the names of the NDArrays, optional, can be NULL |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySave(const char* fname, |
| uint32_t num_args, |
| NDArrayHandle* args, |
| const char** keys); |
| /*! |
| * \brief Load a list of NDArrays from the file. |
| * \param fname name of the file. |
| * \param out_size number of NDArrays loaded. |
| * \param out_arr head of the returning NDArray handles. |
| * \param out_name_size size of output name array. |
| * \param out_names the names of returning NDArrays, can be NULL |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayLoad(const char* fname, |
| uint32_t *out_size, |
| NDArrayHandle** out_arr, |
| uint32_t *out_name_size, |
| const char*** out_names); |
| |
| /*! |
| * \brief Load a list / dictionary of NDArrays from file content loaded into memory. |
| * This will load a list of ndarrays in a similar |
| * manner to MXNDArrayLoad, however, it loads from |
| * buffer containing the contents of a file, rather than |
| * from a specified file. |
| * \param ndarray_buffer pointer to the start of the ndarray file content |
| * \param size size of the file |
| * \param out_size number of NDArrays loaded. |
| * \param out_arr head of the returning NDArray handles. |
| * \param out_name_size size of output name array. |
| * \param out_names the names of returning NDArrays, can be NULL |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, |
| size_t size, |
| uint32_t *out_size, |
| NDArrayHandle** out_arr, |
| uint32_t *out_name_size, |
| const char*** out_names); |
| |
| /*! |
| * \brief Perform a synchronous copy from a contiguous CPU memory region. |
| * |
| * This function will call WaitToWrite before the copy is performed. |
| * This is useful to copy data from an existing memory region that is |
| * not wrapped by NDArray (thus dependency not being tracked). |
| * |
| * \param handle the NDArray handle |
| * \param data the data source to copy from. |
| * \param size the memory size we want to copy from. |
| * NOTE(review): the unit of size (bytes vs. elements) is not stated here - confirm. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, |
| const void *data, |
| size_t size); |
| /*! |
| * \brief Perform a synchronous copy to a contiguous CPU memory region. |
| * |
| * This function will call WaitToRead before the copy is performed. |
| * This is useful to copy data into an existing memory region that is |
| * not wrapped by NDArray (thus dependency not being tracked). |
| * |
| * \param handle the NDArray handle |
| * \param data the destination buffer to copy into. |
| * \param size the memory size we want to copy into. |
| * NOTE(review): the unit of size (bytes vs. elements) is not stated here - confirm. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, |
| void *data, |
| size_t size); |
| |
| /*! |
| * \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0 |
| * This function blocks. Do not use it in performance critical code. |
| * \param handle_dst handle of a dst ndarray whose data/aux_data has been allocated |
| * \param handle_src handle of a src ndarray which has default storage type |
| * \param i dst data blob indicator |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst, |
| const NDArrayHandle handle_src, |
| const int i); |
| |
| /*! |
| * \brief check whether the NDArray format is valid |
| * \param handle the NDArray handle |
| * \param full_check if true, rigorous check, O(N) operations; |
| * otherwise basic check, O(1) operations |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check); |
| |
| /*! |
| * \brief Wait until all the pending writes with respect to the NDArray are finished. |
| * Always call this before reading data out synchronously. |
| * \param handle the NDArray handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayWaitToRead(NDArrayHandle handle); |
| |
| /*! |
| * \brief Wait until all the pending reads/writes with respect to the NDArray are finished. |
| * Always call this before writing data into the NDArray synchronously. |
| * \param handle the NDArray handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayWaitToWrite(NDArrayHandle handle); |
| |
| /*! |
| * \brief wait until all delayed operations in |
| * the system are completed |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayWaitAll(void); |
| |
| /*! |
| * \brief free the NDArray handle |
| * \param handle the handle to be freed |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayFree(NDArrayHandle handle); |
| |
| /*! |
| * \brief Slice the NDArray along axis 0. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param handle the handle to the NDArray |
| * \param slice_begin The beginning index of slice |
| * \param slice_end The ending index of slice |
| * \param out The NDArrayHandle of sliced NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, |
| uint32_t slice_begin, |
| uint32_t slice_end, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief Slice the NDArray along axis 0. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param handle the handle to the NDArray |
| * \param slice_begin The beginning index of slice |
| * \param slice_end The ending index of slice |
| * \param out The NDArrayHandle of sliced NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySlice64(NDArrayHandle handle, |
| int64_t slice_begin, |
| int64_t slice_end, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief Index the NDArray along axis 0. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param handle the handle to the NDArray |
| * \param idx the index |
| * \param out The NDArrayHandle of output NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, |
| uint32_t idx, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief Index the NDArray along axis 0. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param handle the handle to the NDArray |
| * \param idx the index |
| * \param out The NDArrayHandle of output NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayAt64(NDArrayHandle handle, |
| int64_t idx, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief get the storage type of the array |
| * \param handle the handle to the NDArray |
| * \param out_storage_type pointer holder to get the storage type |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, |
| int *out_storage_type); |
| |
| /*! |
| * \brief Reshape the NDArray. |
| * \param handle the handle to the NDArray |
| * \param ndim number of dimensions of new shape |
| * \param dims new shape |
| * \param out the NDArrayHandle of reshaped NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, |
| int ndim, |
| int *dims, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief Reshape the NDArray (variant taking the new shape as 64-bit dim_t values). |
| * \param handle the handle to the NDArray |
| * \param ndim number of dimensions of new shape |
| * \param dims new shape |
| * \param reverse NOTE(review): undocumented flag; presumably selects right-to-left |
| * interpretation of special values in dims - confirm against the implementation |
| * \param out the NDArrayHandle of reshaped NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle, |
| int ndim, |
| dim_t *dims, |
| bool reverse, |
| NDArrayHandle *out); |
| /*! |
| * \brief DEPRECATED. Use MXNDArrayGetShapeEx instead. |
| * get the shape of the array |
| * \param handle the handle to the narray |
| * \param out_dim the output dimension |
| * \param out_pdata pointer holder to get data pointer of the shape |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, |
| uint32_t *out_dim, |
| const uint32_t **out_pdata); |
| |
| /*! |
| * \brief get the shape of the array |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param handle the handle to the narray |
| * \param out_dim the output dimension |
| * \param out_pdata pointer holder to get data pointer of the shape |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle, |
| int *out_dim, |
| const int **out_pdata); |
| |
| /*! |
| * \brief get the shape of the array |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param handle the handle to the narray |
| * \param out_dim the output dimension |
| * \param out_pdata pointer holder to get data pointer of the shape |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle, |
| int *out_dim, |
| const int64_t **out_pdata); |
| |
| /*! |
| * \brief get the content of the data in NDArray |
| * \param handle the handle to the ndarray |
| * \param out_pdata pointer holder to get pointer of data |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, |
| void **out_pdata); |
| /*! |
| * \brief Create a reference view of NDArray that |
| * represents as DLManagedTensor |
| * Notice: MXNet uses asynchronous execution. Please call MXNDArrayWaitToRead or |
| * MXNDArrayWaitToWrite before calling MXNDArrayToDLPack. |
| * \param handle the handle to the ndarray |
| * \param out_dlpack pointer holder to get pointer of DLManagedTensor |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle, |
| DLManagedTensorHandle *out_dlpack); |
| |
| /*! |
| * \brief DEPRECATED. Use MXNDArrayFromDLPackEx instead. |
| * Create a NDArray backed by a dlpack tensor. |
| * |
| * This allows us to create a NDArray using the memory |
| * allocated by an external deep learning framework |
| * that is DLPack compatible. |
| * |
| * The memory is retained until the NDArray goes out of scope. |
| * |
| * \param dlpack the pointer of the input DLManagedTensor |
| * \param out_handle pointer holder to get pointer of NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, |
| NDArrayHandle *out_handle); |
| |
| /*! |
| * \brief Create a NDArray backed by a dlpack tensor. |
| * |
| * This allows us to create a NDArray using the memory |
| * allocated by an external deep learning framework |
| * that is DLPack compatible. |
| * |
| * The memory is retained until the NDArray went out of scope. |
| * |
| * \param dlpack the pointer of the input DLManagedTensor |
| * \param transient_handle whether the handle will be destructed before calling the deleter |
| * \param out_handle pointer holder to get pointer of NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayFromDLPackEx(DLManagedTensorHandle dlpack, |
| const bool transient_handle, |
| NDArrayHandle *out_handle); |
| |
| /*! |
| * \brief Delete a dlpack tensor |
| * \param dlpack the pointer of the input DLManagedTensor |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack); |
| |
| /*! |
| * \brief get the type of the data in NDArray |
| * \param handle the handle to the narray |
| * \param out_dtype pointer holder to get type of data |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, |
| int *out_dtype); |
| |
| /*! |
| * \brief get the type of the ith aux data in NDArray |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param handle the handle to the narray |
| * \param i the index of the aux data |
| * \param out_type pointer holder to get type of aux data |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, |
| uint32_t i, |
| int *out_type); |
| |
| /*! |
| * \brief get the type of the ith aux data in NDArray |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param handle the handle to the narray |
| * \param i the index of the aux data |
| * \param out_type pointer holder to get type of aux data |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle, |
| int64_t i, |
| int *out_type); |
| |
| /*! |
| * \brief Get a deep copy of the ith aux data blob |
| * in the form of an NDArray of default storage type. |
| * This function blocks. Do not use it in performance critical code. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * \param handle the handle to the NDArray |
| * \param i the index of the aux data |
| * \param out pointer holder to get the NDArrayHandle of the copy |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, |
| uint32_t i, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief Get a deep copy of the ith aux data blob |
| * in the form of an NDArray of default storage type. |
| * This function blocks. Do not use it in performance critical code. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * \param handle the handle to the NDArray |
| * \param i the index of the aux data |
| * \param out pointer holder to get the NDArrayHandle of the copy |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle, |
| int64_t i, |
| NDArrayHandle *out); |
| |
| /*! |
| * \brief Get a deep copy of the data blob |
| * in the form of an NDArray of default storage type. |
| * This function blocks. Do not use it in performance critical code. |
| * \param handle the handle to the NDArray |
| * \param out pointer holder to get the NDArrayHandle of the copy |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, |
| NDArrayHandle *out); |
| /*! |
| * \brief get the context of the NDArray |
| * \param handle the handle to the narray |
| * \param out_dev_type the output device type |
| * \param out_dev_id the output device id |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle, |
| int *out_dev_type, |
| int *out_dev_id); |
| /*! |
| * \brief return gradient buffer attached to this NDArray |
| * \param handle NDArray handle |
| * \param out pointer holder to get the NDArrayHandle of the gradient buffer |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out); |
| /*! |
| * \brief detach an ndarray from computation graph by clearing entry_ |
| * \param handle NDArray handle |
| * \param out pointer holder to get the resulting NDArrayHandle; |
| * NOTE(review): presumably the detached array - confirm |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out); |
| /*! |
| * \brief set the flag for gradient array state. |
| * \param handle NDArray handle |
| * \param state the new state. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state); |
| /*! |
| * \brief get the flag for gradient array state. |
| * \param handle NDArray handle |
| * \param out pointer holder to get the current state. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out); |
| //-------------------------------- |
| // Part 2: functions on NDArray |
| //-------------------------------- |
| /*! |
| * \brief list all the available functions handles |
| * most user can use it to list all the needed functions |
| * \param out_size the size of returned array |
| * \param out_array the output function array |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXListFunctions(uint32_t *out_size, |
| FunctionHandle **out_array); |
| |
| /*! |
| * \brief get the function handle by name |
| * \param name the name of the function |
| * \param out the corresponding function handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXGetFunction(const char *name, |
| FunctionHandle *out); |
| /*! |
| * \brief Get the information of the function handle. |
| * \param fun The function handle. |
| * \param name The returned name of the function. |
| * \param description The returned description of the function. |
| * \param num_args Number of arguments. |
| * \param arg_names Name of the arguments. |
| * \param arg_type_infos Type information about the arguments. |
| * \param arg_descriptions Description information about the arguments. |
| * \param return_type Return type of the function. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXFuncGetInfo(FunctionHandle fun, |
| const char **name, |
| const char **description, |
| uint32_t *num_args, |
| const char ***arg_names, |
| const char ***arg_type_infos, |
| const char ***arg_descriptions, |
| const char **return_type DEFAULT(NULL)); |
| /*! |
| * \brief get the argument requirements of the function |
| * \param fun input function handle |
| * \param num_use_vars how many NDArrays to be passed in as used_vars |
| * \param num_scalars how many scalar variables are needed |
| * \param num_mutate_vars how many NDArrays to be passed in as mutate_vars |
| * \param type_mask the type mask of this function |
| * \return 0 when success, -1 when failure happens |
| * \sa MXFuncInvoke |
| */ |
| MXNET_DLL int MXFuncDescribe(FunctionHandle fun, |
| uint32_t *num_use_vars, |
| uint32_t *num_scalars, |
| uint32_t *num_mutate_vars, |
| int *type_mask); |
| /*! |
| * \brief invoke a function, the array size of passed in arguments |
| * must match the values reported by MXFuncDescribe |
| * \param fun the function |
| * \param use_vars the normal arguments passed to function |
| * \param scalar_args the scalar arguments |
| * \param mutate_vars the mutate arguments |
| * \return 0 when success, -1 when failure happens |
| * \sa MXFuncDescribe |
| */ |
| MXNET_DLL int MXFuncInvoke(FunctionHandle fun, |
| NDArrayHandle *use_vars, |
| float *scalar_args, |
| NDArrayHandle *mutate_vars); |
| /*! |
| * \brief invoke a function, the array size of passed in arguments |
| * must match the values reported by MXFuncDescribe |
| * \param fun the function |
| * \param use_vars the normal arguments passed to function |
| * \param scalar_args the scalar arguments |
| * \param mutate_vars the mutate arguments |
| * \param num_params number of keyword parameters |
| * \param param_keys keys for keyword parameters |
| * \param param_vals values for keyword parameters |
| * \return 0 when success, -1 when failure happens |
| * \sa MXFuncDescribe |
| */ |
| MXNET_DLL int MXFuncInvokeEx(FunctionHandle fun, |
| NDArrayHandle *use_vars, |
| float *scalar_args, |
| NDArrayHandle *mutate_vars, |
| int num_params, |
| char **param_keys, |
| char **param_vals); |
| /*! |
| * \brief invoke a nnvm op and imperative function |
| * \param creator the op |
| * \param num_inputs number of input NDArrays |
| * \param inputs input NDArrays |
| * \param num_outputs number of output NDArrays |
| * \param outputs output NDArrays |
| * \param num_params number of keyword parameters |
| * \param param_keys keys for keyword parameters |
| * \param param_vals values for keyword parameters |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator, |
| int num_inputs, |
| NDArrayHandle *inputs, |
| int *num_outputs, |
| NDArrayHandle **outputs, |
| int num_params, |
| const char **param_keys, |
| const char **param_vals); |
| /*! |
| * \brief invoke a nnvm op and imperative function |
| * \param creator the op |
| * \param num_inputs number of input NDArrays |
| * \param inputs input NDArrays |
| * \param num_outputs number of output NDArrays |
| * \param outputs output NDArrays |
| * \param num_params number of keyword parameters |
| * \param param_keys keys for keyword parameters |
| * \param param_vals values for keyword parameters |
| * \param out_stypes output ndarrays' stypes |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator, |
| int num_inputs, |
| NDArrayHandle *inputs, |
| int *num_outputs, |
| NDArrayHandle **outputs, |
| int num_params, |
| const char **param_keys, |
| const char **param_vals, |
| const int **out_stypes); |
| /*! |
| * \brief set whether to record operator for autograd |
| * \param is_recording 1 when recording, 0 when not recording. |
| * \param prev returns the previous status before this set. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int* prev); |
| /*! |
| * \brief set whether training mode is on |
| * \param is_training 1 when training, 0 when testing |
| * \param prev returns the previous status before this set. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradSetIsTraining(int is_training, int* prev); |
| /*! |
| * \brief get whether autograd recording is on |
| * \param curr returns the current status. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradIsRecording(bool* curr); |
| /*! |
| * \brief get whether training mode is on |
| * \param curr returns the current status. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradIsTraining(bool* curr); |
| /*! |
| * \brief get whether numpy compatibility is on |
| * \param curr returns the current status |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXIsNumpyShape(int* curr); |
| /*! |
| * \brief set numpy compatibility switch |
| * \param is_np_shape 1 when numpy shape semantics is thread local on, |
| * 2 when numpy shape semantics is global on and 0 when off |
| * \param prev returns the previous status before this set |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSetIsNumpyShape(int is_np_shape, int* prev); |
| /*! |
| * \brief mark NDArrays as variables to compute gradient for autograd |
| * \param num_var number of variable NDArrays |
| * \param var_handles variable NDArrays |
| * \param reqs_array NOTE(review): undocumented; presumably one gradient request |
| * code per variable - confirm against the implementation |
| * \param grad_handles NOTE(review): undocumented; presumably the gradient |
| * NDArrays paired with var_handles - confirm against the implementation |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var, |
| NDArrayHandle *var_handles, |
| uint32_t *reqs_array, |
| NDArrayHandle *grad_handles); |
| /*! |
| * \brief compute the gradient of outputs w.r.t variables |
| * \param num_output number of output NDArray |
| * \param output_handles output NDArrays |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradComputeGradient(uint32_t num_output, |
| NDArrayHandle* output_handles); |
| /*! |
| * \brief compute the gradient of outputs w.r.t variables |
| * \param num_output number of output NDArray |
| * \param output_handles output NDArrays |
| * \param ograd_handles head gradient for NDArrays |
| * \param retain_graph whether to keep the graph after backward |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradBackward(uint32_t num_output, |
| NDArrayHandle* output_handles, |
| NDArrayHandle* ograd_handles, |
| int retain_graph); |
| /*! |
| * \brief compute the gradient of outputs w.r.t variables |
| * \param num_output number of output NDArray |
| * \param output_handles output NDArrays |
| * \param ograd_handles head gradient for NDArrays |
| * \param num_variables number of variables |
| * \param var_handles variable NDArrays |
| * \param retain_graph whether to keep the graph after backward |
| * \param create_graph NOTE(review): undocumented; presumably whether to record |
| * the backward pass itself for higher-order gradients - confirm |
| * \param is_train whether to do backward for training or inference |
| * \param grad_handles NOTE(review): undocumented; presumably receives the |
| * gradient NDArrays computed for var_handles - confirm |
| * \param grad_stypes NOTE(review): undocumented; presumably receives the |
| * storage types of the arrays in grad_handles - confirm |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXAutogradBackwardEx(uint32_t num_output, |
| NDArrayHandle *output_handles, |
| NDArrayHandle *ograd_handles, |
| uint32_t num_variables, |
| NDArrayHandle *var_handles, |
| int retain_graph, |
| int create_graph, |
| int is_train, |
| NDArrayHandle **grad_handles, |
| int **grad_stypes); |
| /*! |
| * \brief get the graph constructed by autograd. |
| * \param handle ndarray handle |
| * \param out output symbol handle |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out); |
| /*! |
| * \brief create cached operator |
| * \param handle the symbol handle the cached operator is created from |
| * \param out pointer holder to get the created cached operator handle |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out); |
| /*! |
| * \brief create cached operator |
| * \param handle the symbol handle the cached operator is created from |
| * \param num_flags number of flags; NOTE(review): presumably the length of |
| * both keys and vals - confirm |
| * \param keys the keys of the flags |
| * \param vals the values of the flags |
| * \param out pointer holder to get the created cached operator handle |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle, |
| int num_flags, |
| const char** keys, |
| const char** vals, |
| CachedOpHandle *out); |
| |
| /*! |
| * \brief create cached operator, allows to choose thread_safe version |
| * of cachedop |
| * \note this name differs from MXCreateCachedOpEx only in the case of the |
| * final letter; take care to call the intended one |
| * \param handle the symbol handle the cached operator is created from |
| * \param num_flags number of flags; NOTE(review): presumably the length of |
| * both keys and vals - confirm |
| * \param keys the keys of the flags |
| * \param vals the values of the flags |
| * \param out pointer holder to get the created cached operator handle |
| * \param thread_safe whether to create the thread_safe version of the cached |
| * operator (defaults to false when compiled as C++) |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXCreateCachedOpEX(SymbolHandle handle, |
| int num_flags, |
| const char** keys, |
| const char** vals, |
| CachedOpHandle *out, |
| bool thread_safe DEFAULT(false)); |
| |
| /*! |
| * \brief free cached operator |
| */ |
| MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle); |
| |
| /*! |
| * \brief invoke cached operator |
| * \param handle the handle to the cached op |
| * \param num_inputs number of input NDArrays |
| * \param inputs input NDArrays |
| * \param num_outputs number of output NDArrays |
| * \param outputs output NDArrays |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, |
| int num_inputs, |
| NDArrayHandle *inputs, |
| int *num_outputs, |
| NDArrayHandle **outputs); |
| |
| /*! |
| * \brief invoke a cached op |
| * \param handle the handle to the cached op |
| * \param num_inputs number of input NDArrays |
| * \param inputs input NDArrays |
| * \param num_outputs number of output NDArrays |
| * \param outputs output NDArrays |
| * \param out_stypes output ndarrays' stypes |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle, |
| int num_inputs, |
| NDArrayHandle *inputs, |
| int *num_outputs, |
| NDArrayHandle **outputs, |
| const int** out_stypes); |
| |
| /*! |
| * \brief cached op set monitor callback |
| * \param handle NOTE(review): declared as NDArrayHandle although the brief |
| * refers to a cached op - confirm which kind of handle callers must pass |
| * \param callback the monitor callback to register |
| * \param monitor_all NOTE(review): undocumented; presumably widens what the |
| * callback is invoked for - confirm against the implementation |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle, |
| CachedOpMonitorCallback callback, |
| bool monitor_all); |
| |
| //-------------------------------------------- |
| // Part 3: symbolic configuration generation |
| //-------------------------------------------- |
| /*! |
| * \brief list all the available operator names, include entries. |
| * \param out_size the size of returned array |
| * \param out_array the output operator name array. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXListAllOpNames(uint32_t *out_size, |
| const char ***out_array); |
| |
| /*! |
| * \brief list all the available AtomicSymbolEntry |
| * \param out_size the size of returned array |
| * \param out_array the output AtomicSymbolCreator array |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolListAtomicSymbolCreators(uint32_t *out_size, |
| AtomicSymbolCreator **out_array); |
| |
| /*! |
| * \brief Get the name of an atomic symbol. |
| * \param creator the AtomicSymbolCreator. |
| * \param name The returned name of the creator. |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, |
| const char **name); |
| |
| /*! |
| * \brief Get the input symbols of the graph. |
| * \param sym The graph. |
| * \param inputs The input symbols of the graph. |
| * \param input_size the number of input symbols returned. |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs, |
| int *input_size); |
| |
| /*! |
| * \brief Cut a subgraph whose nodes are marked with a subgraph attribute. |
| * The input graph will be modified. A variable node will be created for each |
| * edge that connects to nodes outside the subgraph. The outside nodes that |
| * connect to the subgraph will be returned. |
| * \param sym The graph. |
| * \param inputs The nodes that connect to the subgraph. |
| * \param input_size The number of such nodes. |
| * \return 0 when success, -1 when failure happens (header-wide convention) |
| */ |
| MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs, |
| int *input_size); |
| |
| /*! |
| * \brief Get the detailed information about atomic symbol. |
| * \param creator the AtomicSymbolCreator. |
| * \param name The returned name of the creator. |
| * \param description The returned description of the symbol. |
| * \param num_args Number of arguments. |
| * \param arg_names Name of the arguments. |
| * \param arg_type_infos Type information about the arguments. |
| * \param arg_descriptions Description information about the arguments. |
| * \param key_var_num_args The keyword argument for specifying variable number of arguments. |
| * When this parameter has non-zero length, the function allows variable number |
| * of positional arguments, and will need the caller to pass it in to |
| * MXSymbolCreateAtomicSymbol, |
| * With key = key_var_num_args, and value = number of positional arguments. |
| * \param return_type Return type of the function, can be Symbol or Symbol[] |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, |
| const char **name, |
| const char **description, |
| uint32_t *num_args, |
| const char ***arg_names, |
| const char ***arg_type_infos, |
| const char ***arg_descriptions, |
| const char **key_var_num_args, |
| const char **return_type DEFAULT(NULL)); |
| /*! |
| * \brief Create an AtomicSymbol. |
| * \param creator the AtomicSymbolCreator |
| * \param num_param the number of parameters |
| * \param keys the keys to the params |
| * \param vals the vals of the params |
| * \param out pointer to the created symbol handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, |
| uint32_t num_param, |
| const char **keys, |
| const char **vals, |
| SymbolHandle *out); |
| /*! |
| * \brief Create a Variable Symbol. |
| * \param name name of the variable |
| * \param out pointer to the created symbol handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out); |
| /*! |
| * \brief Create a Symbol by grouping list of symbols together |
| * \param num_symbols number of symbols to be grouped |
| * \param symbols array of symbol handles |
| * \param out pointer to the created symbol handle |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCreateGroup(uint32_t num_symbols, |
| SymbolHandle *symbols, |
| SymbolHandle *out); |
| /*! |
| * \brief Load a symbol from a json file. |
| * \param fname the file name. |
| * \param out the output symbol. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out); |
| /*! |
| * \brief Load a symbol from a json string. |
| * \param json the json string. |
| * \param out the output symbol. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out); |
| /*! |
| * \brief Remove the operators amp_cast and amp_multicast |
| * \param sym_handle the input symbol. |
| * \param ret_sym_handle the output symbol. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle); |
| /*! |
| * \brief Save a symbol into a json file. |
| * \param symbol the input symbol. |
| * \param fname the file name. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname); |
| /*! |
| * \brief Save a symbol into a json string |
| * \param symbol the input symbol. |
| * \param out_json output json string. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json); |
| /*! |
| * \brief Free the symbol handle. |
| * \param symbol the symbol |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolFree(SymbolHandle symbol); |
| /*! |
| * \brief Copy the symbol to another handle |
| * \param symbol the source symbol |
| * \param out used to hold the result of copy |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out); |
| /*! |
| * \brief Print the content of symbol, used for debug. |
| * \param symbol the symbol |
| * \param out_str pointer to hold the output string of the printing. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str); |
| /*! |
| * \brief Get string name from symbol |
| * \param symbol the source symbol |
| * \param out The result name. |
| * \param success Whether the result is contained in out. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetName(SymbolHandle symbol, |
| const char** out, |
| int *success); |
| /*! |
| * \brief Get string attribute from symbol |
| * \param symbol the source symbol |
| * \param key The key of the symbol. |
| * \param out The result attribute, can be NULL if the attribute do not exist. |
| * \param success Whether the result is contained in out. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetAttr(SymbolHandle symbol, |
| const char* key, |
| const char** out, |
| int *success); |
| /*! |
| * \brief Set string attribute on symbol. |
| * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. |
| * |
| * Safe recommendation: use immutable graph |
| * - Only allow set attributes during creation of new symbol as optional parameter |
| * |
| * Mutable graph (be careful about the semantics): |
| * - Allow set attr at any point. |
| * - Mutating an attribute of some common node of two graphs can cause confusion from user. |
| * |
| * \param symbol the source symbol |
| * \param key The key of the symbol. |
| * \param value The value to be saved. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolSetAttr(SymbolHandle symbol, |
| const char* key, |
| const char* value); |
| /*! |
| * \brief Get all attributes from symbol, including all descendents. |
| * \param symbol the source symbol |
| * \param out_size The number of output attributes |
| * \param out 2*out_size strings representing key value pairs. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolListAttr(SymbolHandle symbol, |
| uint32_t *out_size, |
| const char*** out); |
| /*! |
| * \brief Get all attributes from symbol, excluding descendents. |
| * \param symbol the source symbol |
| * \param out_size The number of output attributes |
| * \param out 2*out_size strings representing key value pairs. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol, |
| uint32_t *out_size, |
| const char*** out); |
| /*! |
| * \brief List arguments in the symbol. |
| * \param symbol the symbol |
| * \param out_size output size |
| * \param out_str_array pointer to hold the output string array |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, |
| uint32_t *out_size, |
| const char ***out_str_array); |
| |
| /*! |
| * \brief List returns in the symbol. |
| * \param symbol the symbol |
| * \param out_size output size |
| * \param out_str_array pointer to hold the output string array |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, |
| uint32_t *out_size, |
| const char ***out_str_array); |
| |
| /*! |
| * \brief Get number of outputs of the symbol. |
| * \param symbol The symbol |
| * \param output_count number of outputs |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol, |
| uint32_t *output_count); |
| |
| /*! |
| * \brief Get a symbol that contains all the internals. |
| * \param symbol The symbol |
| * \param out The output symbol whose outputs are all the internals. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, |
| SymbolHandle *out); |
| /*! |
| * \brief Get a symbol that contains only direct children. |
| * \param symbol The symbol |
| * \param out The output symbol whose outputs are the direct children. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol, |
| SymbolHandle *out); |
| /*! |
| * \brief Get index-th outputs of the symbol. |
| * \param symbol The symbol |
| * \param index the Index of the output. |
| * \param out The output symbol whose outputs are the index-th symbol. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol, |
| uint32_t index, |
| SymbolHandle *out); |
| |
| /*! |
| * \brief List auxiliary states in the symbol. |
| * \param symbol the symbol |
| * \param out_size output size |
| * \param out_str_array pointer to hold the output string array |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, |
| uint32_t *out_size, |
| const char ***out_str_array); |
| |
| /*! |
| * \brief Compose the symbol on other symbols. |
| * |
| * This function will change the sym handle. |
| * To achieve function apply behavior, copy the symbol first |
| * before apply. |
| * |
| * \param sym the symbol to apply |
| * \param name the name of symbol |
| * \param num_args number of arguments |
| * \param keys the key of keyword args (optional) |
| * \param args arguments to sym |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXSymbolCompose(SymbolHandle sym, |
| const char *name, |
| uint32_t num_args, |
| const char** keys, |
| SymbolHandle* args); |
| /*! |
| * \brief Get the gradient graph of the symbol. |
| * |
| * \param sym the symbol to get the gradient of |
| * \param num_wrt number of arguments to get the gradient with respect to |
| * \param wrt the names of the arguments to get the gradient with respect to |
| * \param out receives the returned symbol that has the gradient |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolGrad(SymbolHandle sym, |
| uint32_t num_wrt, |
| const char** wrt, |
| SymbolHandle* out); |
| /*! |
| * \brief DEPRECATED. Use MXSymbolInferShapeEx instead. |
| * Infer the shapes of unknown inputs given the known ones. |
| * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * In this variant all shape values and dimension counts are uint32_t. |
| * |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_ind_ptr the head pointer of the rows in CSR |
| * \param arg_shape_data the content of the CSR |
| * \param in_shape_size size of the returned array of in_shapes |
| * \param in_shape_ndim returned array of shape dimensions of each input shape. |
| * \param in_shape_data returned array of pointers to the head of each input shape. |
| * \param out_shape_size size of the returned array of out_shapes |
| * \param out_shape_ndim returned array of shape dimensions of each output shape. |
| * \param out_shape_data returned array of pointers to the head of each output shape. |
| * \param aux_shape_size size of the returned array of aux_shapes |
| * \param aux_shape_ndim returned array of shape dimensions of each auxiliary shape. |
| * \param aux_shape_data returned array of pointers to the head of each auxiliary shape. |
| * \param complete whether shape inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const uint32_t *arg_ind_ptr, |
| const uint32_t *arg_shape_data, |
| uint32_t *in_shape_size, |
| const uint32_t **in_shape_ndim, |
| const uint32_t ***in_shape_data, |
| uint32_t *out_shape_size, |
| const uint32_t **out_shape_ndim, |
| const uint32_t ***out_shape_data, |
| uint32_t *aux_shape_size, |
| const uint32_t **aux_shape_ndim, |
| const uint32_t ***aux_shape_data, |
| int *complete); |
| |
| /*! |
| * \brief Infer the shapes of unknown inputs given the known ones. |
| * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * In this variant shape values and dimension counts are int; |
| * arg_ind_ptr and the returned sizes are uint32_t. |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_ind_ptr the head pointer of the rows in CSR |
| * \param arg_shape_data the content of the CSR |
| * \param in_shape_size size of the returned array of in_shapes |
| * \param in_shape_ndim returned array of shape dimensions of each input shape. |
| * \param in_shape_data returned array of pointers to the head of each input shape. |
| * \param out_shape_size size of the returned array of out_shapes |
| * \param out_shape_ndim returned array of shape dimensions of each output shape. |
| * \param out_shape_data returned array of pointers to the head of each output shape. |
| * \param aux_shape_size size of the returned array of aux_shapes |
| * \param aux_shape_ndim returned array of shape dimensions of each auxiliary shape. |
| * \param aux_shape_data returned array of pointers to the head of each auxiliary shape. |
| * \param complete whether shape inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const uint32_t *arg_ind_ptr, |
| const int *arg_shape_data, |
| uint32_t *in_shape_size, |
| const int **in_shape_ndim, |
| const int ***in_shape_data, |
| uint32_t *out_shape_size, |
| const int **out_shape_ndim, |
| const int ***out_shape_data, |
| uint32_t *aux_shape_size, |
| const int **aux_shape_ndim, |
| const int ***aux_shape_data, |
| int *complete); |
| |
| /*! |
| * \brief Infer the shapes of unknown inputs given the known ones. |
| * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * In this variant arg_ind_ptr and all shape values are int64_t, the returned |
| * sizes are size_t, and the dimension counts are int. |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_ind_ptr the head pointer of the rows in CSR |
| * \param arg_shape_data the content of the CSR |
| * \param in_shape_size size of the returned array of in_shapes |
| * \param in_shape_ndim returned array of shape dimensions of each input shape. |
| * \param in_shape_data returned array of pointers to the head of each input shape. |
| * \param out_shape_size size of the returned array of out_shapes |
| * \param out_shape_ndim returned array of shape dimensions of each output shape. |
| * \param out_shape_data returned array of pointers to the head of each output shape. |
| * \param aux_shape_size size of the returned array of aux_shapes |
| * \param aux_shape_ndim returned array of shape dimensions of each auxiliary shape. |
| * \param aux_shape_data returned array of pointers to the head of each auxiliary shape. |
| * \param complete whether shape inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const int64_t *arg_ind_ptr, |
| const int64_t *arg_shape_data, |
| size_t *in_shape_size, |
| const int **in_shape_ndim, |
| const int64_t ***in_shape_data, |
| size_t *out_shape_size, |
| const int **out_shape_ndim, |
| const int64_t ***out_shape_data, |
| size_t *aux_shape_size, |
| const int **aux_shape_ndim, |
| const int64_t ***aux_shape_data, |
| int *complete); |
| |
| /*! |
| * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead. |
| * Partially infer the shapes of unknown inputs given the known ones. |
| * |
| * Return partially inferred results if not all shapes could be inferred. |
| * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * In this variant all shape values and dimension counts are uint32_t. |
| * |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_ind_ptr the head pointer of the rows in CSR |
| * \param arg_shape_data the content of the CSR |
| * \param in_shape_size size of the returned array of in_shapes |
| * \param in_shape_ndim returned array of shape dimensions of each input shape. |
| * \param in_shape_data returned array of pointers to the head of each input shape. |
| * \param out_shape_size size of the returned array of out_shapes |
| * \param out_shape_ndim returned array of shape dimensions of each output shape. |
| * \param out_shape_data returned array of pointers to the head of each output shape. |
| * \param aux_shape_size size of the returned array of aux_shapes |
| * \param aux_shape_ndim returned array of shape dimensions of each auxiliary shape. |
| * \param aux_shape_data returned array of pointers to the head of each auxiliary shape. |
| * \param complete whether shape inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const uint32_t *arg_ind_ptr, |
| const uint32_t *arg_shape_data, |
| uint32_t *in_shape_size, |
| const uint32_t **in_shape_ndim, |
| const uint32_t ***in_shape_data, |
| uint32_t *out_shape_size, |
| const uint32_t **out_shape_ndim, |
| const uint32_t ***out_shape_data, |
| uint32_t *aux_shape_size, |
| const uint32_t **aux_shape_ndim, |
| const uint32_t ***aux_shape_data, |
| int *complete); |
| |
| /*! |
| * \brief Partially infer the shapes of unknown inputs given the known ones. |
| * |
| * Return partially inferred results if not all shapes could be inferred. |
| * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=0 (by default) |
| * In this variant shape values and dimension counts are int; |
| * arg_ind_ptr and the returned sizes are uint32_t. |
| * |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_ind_ptr the head pointer of the rows in CSR |
| * \param arg_shape_data the content of the CSR |
| * \param in_shape_size size of the returned array of in_shapes |
| * \param in_shape_ndim returned array of shape dimensions of each input shape. |
| * \param in_shape_data returned array of pointers to the head of each input shape. |
| * \param out_shape_size size of the returned array of out_shapes |
| * \param out_shape_ndim returned array of shape dimensions of each output shape. |
| * \param out_shape_data returned array of pointers to the head of each output shape. |
| * \param aux_shape_size size of the returned array of aux_shapes |
| * \param aux_shape_ndim returned array of shape dimensions of each auxiliary shape. |
| * \param aux_shape_data returned array of pointers to the head of each auxiliary shape. |
| * \param complete whether shape inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const uint32_t *arg_ind_ptr, |
| const int *arg_shape_data, |
| uint32_t *in_shape_size, |
| const int **in_shape_ndim, |
| const int ***in_shape_data, |
| uint32_t *out_shape_size, |
| const int **out_shape_ndim, |
| const int ***out_shape_data, |
| uint32_t *aux_shape_size, |
| const int **aux_shape_ndim, |
| const int ***aux_shape_data, |
| int *complete); |
| |
| /*! |
| * \brief Partially infer the shapes of unknown inputs given the known ones. |
| * |
| * Return partially inferred results if not all shapes could be inferred. |
| * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * This api is available when MXNet is built with flag |
| * USE_INT64_TENSOR_SIZE=1 (not default) i.e. Large Tensor Support |
| * In this variant arg_ind_ptr and all shape values are int64_t, the returned |
| * sizes are size_t, and the dimension counts are int. |
| * |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_ind_ptr the head pointer of the rows in CSR |
| * \param arg_shape_data the content of the CSR |
| * \param in_shape_size size of the returned array of in_shapes |
| * \param in_shape_ndim returned array of shape dimensions of each input shape. |
| * \param in_shape_data returned array of pointers to the head of each input shape. |
| * \param out_shape_size size of the returned array of out_shapes |
| * \param out_shape_ndim returned array of shape dimensions of each output shape. |
| * \param out_shape_data returned array of pointers to the head of each output shape. |
| * \param aux_shape_size size of the returned array of aux_shapes |
| * \param aux_shape_ndim returned array of shape dimensions of each auxiliary shape. |
| * \param aux_shape_data returned array of pointers to the head of each auxiliary shape. |
| * \param complete whether shape inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const int64_t *arg_ind_ptr, |
| const int64_t *arg_shape_data, |
| size_t *in_shape_size, |
| const int **in_shape_ndim, |
| const int64_t ***in_shape_data, |
| size_t *out_shape_size, |
| const int **out_shape_ndim, |
| const int64_t ***out_shape_data, |
| size_t *aux_shape_size, |
| const int **aux_shape_ndim, |
| const int64_t ***aux_shape_data, |
| int *complete); |
| |
| /*! |
| * \brief Infer the types of unknown inputs given the known ones. |
| * The known types are supplied through arg_type_data; unlike the shape-inference |
| * functions in this header, this function takes no arg_ind_ptr parameter. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_type_data the known argument types |
| * \param in_type_size size of the returned array of in_types |
| * \param in_type_data receives a pointer to the returned array of input types. |
| * \param out_type_size size of the returned array of out_types |
| * \param out_type_data receives a pointer to the returned array of output types. |
| * \param aux_type_size size of the returned array of aux_types |
| * \param aux_type_data receives a pointer to the returned array of auxiliary types. |
| * \param complete whether type inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferType(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const int *arg_type_data, |
| uint32_t *in_type_size, |
| const int **in_type_data, |
| uint32_t *out_type_size, |
| const int **out_type_data, |
| uint32_t *aux_type_size, |
| const int **aux_type_data, |
| int *complete); |
| |
| /*! |
| * \brief Partially infer the types of unknown inputs given the known ones. |
| * |
| * Return partially inferred results if not all types could be inferred. |
| * The known types are supplied through arg_type_data; unlike the shape-inference |
| * functions in this header, this function takes no arg_ind_ptr parameter. |
| * The call will be treated as a kwargs call if keys != NULL or num_args==0, otherwise it is positional. |
| * |
| * \param sym symbol handle |
| * \param num_args number of input arguments. |
| * \param keys the keys of the keyword args (optional) |
| * \param arg_type_data the known argument types |
| * \param in_type_size size of the returned array of in_types |
| * \param in_type_data receives a pointer to the returned array of input types. |
| * \param out_type_size size of the returned array of out_types |
| * \param out_type_data receives a pointer to the returned array of output types. |
| * \param aux_type_size size of the returned array of aux_types |
| * \param aux_type_data receives a pointer to the returned array of auxiliary types. |
| * \param complete whether type inference completed or more information is needed. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym, |
| uint32_t num_args, |
| const char** keys, |
| const int *arg_type_data, |
| uint32_t *in_type_size, |
| const int **in_type_data, |
| uint32_t *out_type_size, |
| const int **out_type_data, |
| uint32_t *aux_type_size, |
| const int **aux_type_data, |
| int *complete); |
| |
| /*! |
| * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 |
| * \param sym_handle symbol to be converted |
| * \param ret_sym_handle receives the quantized symbol result |
| * \param dev_type device type |
| * \param num_excluded_sym_names number of layers excluded from being quantized in the input symbol |
| * \param excluded_sym_names node names to be excluded from being quantized |
| * \param num_excluded_op_names number of operators excluded from being quantized in the input symbol |
| * \param excluded_op_names operator names to be excluded from being quantized |
| * \param num_offline number of parameters that are quantized offline |
| * \param offline_params array of c strings representing the names of params quantized offline |
| * \param quantized_dtype the quantized destination type for input data |
| * \param calib_quantize **Deprecated**. The quantize op will always be calibrated if possible |
| * \param quantize_mode quantize mode to be used in quantize pass |
| * \param quantize_granularity quantize granularity, tensor-wise or channel-wise |
| * \param out_num_calib_names returns the number of nodes to be calibrated |
| * \param out_calib_names returns the node names to be calibrated |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other functions in this header - confirm. |
| */ |
| MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, |
| SymbolHandle *ret_sym_handle, |
| const int* dev_type, |
| const uint32_t num_excluded_sym_names, |
| const char **excluded_sym_names, |
| const uint32_t num_excluded_op_names, |
| const char **excluded_op_names, |
| const uint32_t num_offline, const char **offline_params, |
| const char *quantized_dtype, const bool calib_quantize, |
| const char *quantize_mode, const char *quantize_granularity, |
| uint32_t* out_num_calib_names, const char ***out_calib_names); |
| |
| /*! |
| * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting |
| * \param sym_handle symbol to be converted |
| * \param ret_sym_handle receives the mixed precision symbol result |
| * \param num_args number of arguments for known dtypes |
| * \param arg_type_data arg types of the arguments |
| * \param num_ind_ptr NOTE(review): not described by the existing documentation; |
| *        presumably the number of entries in ind_ptr - TODO confirm |
| * \param ind_ptr NOTE(review): not described by the existing documentation - TODO document |
| * \param target_dtype target_dtype for mixed precision symbol |
| * \param cast_optional_params whether to cast optional params to target_dtype |
| * \param num_target_dtype_op_names number of ops to be casted to target_dtype |
| * \param num_fp32_op_names number of ops to be casted to FP32 |
| * \param num_widest_dtype_op_names number of ops to be casted to the widest dtype |
| * \param num_conditional_fp32_op_names number of ops to be casted to FP32 based on a condition |
| * \param num_excluded_symbols number of symbols to be excluded from casting |
| * \param num_model_params number of model parameters |
| * \param target_dtype_op_names op names to be casted to target_dtype |
| * \param fp32_op_names op names to be casted to fp32 |
| * \param widest_dtype_op_names names to be casted to widest dtype |
| * \param conditional_fp32_op_names names to be casted to FP32 conditionally |
| * \param excluded_symbols symbol names to be excluded from casting |
| * \param conditional_param_names param names for conditional FP32 casting |
| * \param conditional_param_vals param values for conditional FP32 casting |
| * \param model_param_names names for model parameters |
| * \param arg_names argument names for which type information is provided |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other functions in this header - confirm. |
| */ |
| MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle, |
| SymbolHandle *ret_sym_handle, |
| uint32_t num_args, |
| const int* arg_type_data, |
| uint32_t num_ind_ptr, |
| const int* ind_ptr, |
| const int* target_dtype, |
| const int cast_optional_params, |
| const uint32_t num_target_dtype_op_names, |
| const uint32_t num_fp32_op_names, |
| const uint32_t num_widest_dtype_op_names, |
| const uint32_t num_conditional_fp32_op_names, |
| const uint32_t num_excluded_symbols, |
| const uint32_t num_model_params, |
| const char **target_dtype_op_names, |
| const char **fp32_op_names, |
| const char **widest_dtype_op_names, |
| const char **conditional_fp32_op_names, |
| const char **excluded_symbols, |
| const char **conditional_param_names, |
| const char **conditional_param_vals, |
| const char **model_param_names, |
| const char **arg_names); |
| /*! |
| * \brief Set calibration table to node attributes in the sym |
| * \param qsym_handle symbol whose node attributes are to be set by calibration table |
| * \param num_layers number of layers in the calibration table |
| * \param layer_names layer names stored as keys in the calibration table |
| * \param low_quantiles low quantiles of layers stored in the calibration table |
| * \param high_quantiles high quantiles of layers stored in the calibration table |
| * \param ret_sym_handle receives the returned symbol |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other functions in this header - confirm. |
| */ |
| MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, |
| const uint32_t num_layers, |
| const char** layer_names, |
| const float* low_quantiles, |
| const float* high_quantiles, |
| SymbolHandle* ret_sym_handle); |
| |
| /*! |
| * \brief Run subgraph pass based on the backend provided |
| * \param sym_handle symbol to be converted |
| * \param backend backend name for the subgraph pass |
| * \param ret_sym_handle receives the returned symbol |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other functions in this header - confirm. |
| */ |
| MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend, |
| SymbolHandle *ret_sym_handle); |
| |
| /*! |
| * \brief Generate atomic symbol (able to be composed) from a source symbol |
| * \param sym_handle source symbol |
| * \param ret_sym_handle receives the returned atomic symbol |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other functions in this header - confirm. |
| */ |
| MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle); |
| /*! |
| * \brief Partitions symbol for given backend, potentially creating subgraphs |
| * \param sym_handle symbol to be partitioned |
| * \param backend_name backend name |
| * \param dev_type context device type |
| * \param ret_sym_handle partitioned symbol returned |
| * \param args_len number of args |
| * \param in_args_handle args array |
| * \param aux_len NOTE(review): not described by the existing documentation; |
| *        presumably the number of aux entries, by analogy with args_len - TODO confirm |
| * \param in_aux_handle NOTE(review): not described by the existing documentation; |
| *        presumably the aux array, by analogy with in_args_handle - TODO confirm |
| * \param num_options number of key value pairs |
| * \param keys keys for options |
| * \param vals values corresponding to keys |
| * \param num_input_shapes number of input shapes |
| * \param input_shape_names names of the input shapes |
| * \param input_shape_data pointer to the contiguous data shapes |
| * \param input_shape_idx array of per shape starting idx, the shape length for the i-th input shape |
| * is calculated as input_shape_idx[i+1] - input_shape_idx[i] |
| * \param num_input_dtypes number of input data types |
| * \param input_dtype_names array of names of the input data types |
| * \param input_dtypes array of values of the input data types |
| * \param num_input_stypes number of input storage types |
| * \param input_stype_names array of names of the input storage types |
| * \param input_stypes array of values of input storage types |
| * \param skip_infer if the optimization should skip the attribute inferences |
| * (to use if the backend does not require shape inference) |
| * \param new_args_cnt pointer to a number to store the number of new args |
| * \param new_args_handle pointer to an array to store the new args handles |
| * \param new_arg_names_handle pointer to an array to store the new args names |
| * \param new_aux_cnt pointer to a number to store the number of new aux |
| * \param new_aux_handle pointer to an array to store the new aux handles |
| * \param new_aux_names_handle pointer to an array to store the new aux names |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other functions in this header - confirm. |
| */ |
| MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle, |
| const char* backend_name, |
| const int dev_type, |
| SymbolHandle* ret_sym_handle, |
| const mx_uint args_len, |
| NDArrayHandle* in_args_handle, |
| const mx_uint aux_len, |
| NDArrayHandle* in_aux_handle, |
| const mx_uint num_options, |
| const char** keys, |
| const char** vals, |
| const uint32_t num_input_shapes, |
| const char** input_shape_names, |
| const int64_t* input_shape_data, |
| const uint32_t* input_shape_idx, |
| const uint32_t num_input_dtypes, |
| const char** input_dtype_names, |
| const int* input_dtypes, |
| const uint32_t num_input_stypes, |
| const char** input_stype_names, |
| const int* input_stypes, |
| bool skip_infer, |
| int* new_args_cnt, |
| NDArrayHandle** new_args_handle, |
| char*** new_arg_names_handle, |
| int* new_aux_cnt, |
| NDArrayHandle** new_aux_handle, |
| char*** new_aux_names_handle); |
| |
| |
| //-------------------------------------------- |
| // Part 4: Executor interface |
| //-------------------------------------------- |
| /*! |
| * \brief Delete the executor. |
| * \param handle the executor handle to delete. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorFree(ExecutorHandle handle); |
| /*! |
| * \brief Print the content of the execution plan; used for debugging. |
| * \param handle the executor handle. |
| * \param out_str pointer that receives the output string of the printing. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorPrint(ExecutorHandle handle, const char **out_str); |
| /*! |
| * \brief Executor forward method |
| * |
| * \param handle executor handle |
| * \param is_train int value indicating whether the forward pass is for training or for evaluation. |
| *        NOTE(review): the parameter name suggests non-zero means training - confirm. |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train); |
| /*! |
| * \brief Executor run backward |
| * |
| * \param handle executor handle |
| * \param len length (presumably the number of entries in head_grads - TODO confirm) |
| * \param head_grads NDArray handles for the heads' gradients |
| * |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, |
| uint32_t len, |
| NDArrayHandle *head_grads); |
| /*! |
| * \brief Executor run backward, with an explicit is_train flag |
| * |
| * \param handle executor handle |
| * \param len length (presumably the number of entries in head_grads - TODO confirm) |
| * \param head_grads NDArray handles for the heads' gradients |
| * \param is_train int value indicating whether the backward pass is for training or for evaluation. |
| *        NOTE(review): the parameter name suggests non-zero means training - confirm. |
| * |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorBackwardEx(ExecutorHandle handle, |
| uint32_t len, |
| NDArrayHandle *head_grads, |
| int is_train); |
| /*! |
| * \brief Get the executor's head NDArrays |
| * |
| * \param handle executor handle |
| * \param out_size receives the size of the output NDArray vector |
| * \param out receives the output NDArray handles |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, |
| uint32_t *out_size, |
| NDArrayHandle **out); |
| |
| /*! |
| * \brief Generate an Executor from a symbol |
| * |
| * \param symbol_handle symbol handle |
| * \param dev_type device type |
| * \param dev_id device id |
| * \param len length (presumably the number of entries in in_args - TODO confirm) |
| * \param in_args in args array |
| * \param arg_grad_store arg grads handle array |
| * \param grad_req_type grad req array |
| * \param aux_states_len length of auxiliary states |
| * \param aux_states auxiliary states array |
| * \param out receives the output executor handle |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle, |
| int dev_type, |
| int dev_id, |
| uint32_t len, |
| NDArrayHandle *in_args, |
| NDArrayHandle *arg_grad_store, |
| uint32_t *grad_req_type, |
| uint32_t aux_states_len, |
| NDArrayHandle *aux_states, |
| ExecutorHandle *out); |
| /*! |
| * \brief Generate an Executor from a symbol. |
| * This is an advanced function that allows specifying a group2ctx map. |
| * The user can annotate the "ctx_group" attribute to name each group. |
| * |
| * \param symbol_handle symbol handle |
| * \param dev_type device type of default context |
| * \param dev_id device id of default context |
| * \param num_map_keys size of group2ctx map |
| * \param map_keys keys of group2ctx map |
| * \param map_dev_types device type of group2ctx map |
| * \param map_dev_ids device id of group2ctx map |
| * \param len length (presumably the number of entries in in_args - TODO confirm) |
| * \param in_args in args array |
| * \param arg_grad_store arg grads handle array |
| * \param grad_req_type grad req array |
| * \param aux_states_len length of auxiliary states |
| * \param aux_states auxiliary states array |
| * \param out receives the output executor handle |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle, |
| int dev_type, |
| int dev_id, |
| uint32_t num_map_keys, |
| const char** map_keys, |
| const int* map_dev_types, |
| const int* map_dev_ids, |
| uint32_t len, |
| NDArrayHandle *in_args, |
| NDArrayHandle *arg_grad_store, |
| uint32_t *grad_req_type, |
| uint32_t aux_states_len, |
| NDArrayHandle *aux_states, |
| ExecutorHandle *out); |
| /*! |
| * \brief Generate an Executor from a symbol. |
| * This is an advanced function that allows specifying a group2ctx map. |
| * The user can annotate the "ctx_group" attribute to name each group. |
| * The parameter list matches MXExecutorBindX with one additional argument, |
| * shared_exec, placed before out. |
| * |
| * \param symbol_handle symbol handle |
| * \param dev_type device type of default context |
| * \param dev_id device id of default context |
| * \param num_map_keys size of group2ctx map |
| * \param map_keys keys of group2ctx map |
| * \param map_dev_types device type of group2ctx map |
| * \param map_dev_ids device id of group2ctx map |
| * \param len length (presumably the number of entries in in_args - TODO confirm) |
| * \param in_args in args array |
| * \param arg_grad_store arg grads handle array |
| * \param grad_req_type grad req array |
| * \param aux_states_len length of auxiliary states |
| * \param aux_states auxiliary states array |
| * \param shared_exec input executor handle for memory sharing |
| * \param out receives the output executor handle |
| * \return 0 on success, -1 on failure |
| */ |
| MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, |
| int dev_type, |
| int dev_id, |
| uint32_t num_map_keys, |
| const char** map_keys, |
| const int* map_dev_types, |
| const int* map_dev_ids, |
| uint32_t len, |
| NDArrayHandle *in_args, |
| NDArrayHandle *arg_grad_store, |
| uint32_t *grad_req_type, |
| uint32_t aux_states_len, |
| NDArrayHandle *aux_states, |
| ExecutorHandle shared_exec, |
| ExecutorHandle *out); |
| /*! \brief DEPRECATED. Use MXExecutorSimpleBindEx instead. |
| * |
| * This variant takes provided_arg_shape_data as const uint32_t*. |
| * MXExecutorSimpleBindEx and MXExecutorSimpleBindEx64 are declared with the same |
| * parameter list except that provided_arg_shape_data is const int* and |
| * const int64_t* respectively. |
| * NOTE(review): the individual parameters are not documented in this header; |
| * "g2c" presumably abbreviates the group2ctx map used by MXExecutorBindX - TODO confirm. |
| */ |
| MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, |
| int dev_type, |
| int dev_id, |
| const uint32_t num_g2c_keys, |
| const char** g2c_keys, |
| const int* g2c_dev_types, |
| const int* g2c_dev_ids, |
| const uint32_t provided_grad_req_list_len, |
| const char** provided_grad_req_names, |
| const char** provided_grad_req_types, |
| const uint32_t num_provided_arg_shapes, |
| const char** provided_arg_shape_names, |
| const uint32_t* provided_arg_shape_data, |
| const uint32_t* provided_arg_shape_idx, |
| const uint32_t num_provided_arg_dtypes, |
| const char** provided_arg_dtype_names, |
| const int* provided_arg_dtypes, |
| const uint32_t num_provided_arg_stypes, |
| const char** provided_arg_stype_names, |
| const int* provided_arg_stypes, |
| const uint32_t num_shared_arg_names, |
| const char** shared_arg_name_list, |
| int* shared_buffer_len, |
| const char** shared_buffer_name_list, |
| NDArrayHandle* shared_buffer_handle_list, |
| const char*** updated_shared_buffer_name_list, |
| NDArrayHandle** updated_shared_buffer_handle_list, |
| uint32_t* num_in_args, |
| NDArrayHandle** in_args, |
| NDArrayHandle** arg_grads, |
| uint32_t* num_aux_states, |
| NDArrayHandle** aux_states, |
| ExecutorHandle shared_exec_handle, |
| ExecutorHandle* out); |
| |
| /*! |
| * \brief Simple-bind entry point named as the replacement for the deprecated |
| * MXExecutorSimpleBind. |
| * |
| * Declared with the same parameter list as MXExecutorSimpleBind except that |
| * provided_arg_shape_data is const int* instead of const uint32_t*. |
| * NOTE(review): the individual parameters are not documented in this header; |
| * "g2c" presumably abbreviates the group2ctx map used by MXExecutorBindX - TODO confirm. |
| * \return int status code. NOTE(review): presumably 0 on success and -1 on failure, |
| *         as documented for the other executor functions - confirm. |
| */ |
| MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, |
| int dev_type, |
| int dev_id, |
| const uint32_t num_g2c_keys, |
| const char** g2c_keys, |
| const int* g2c_dev_types, |
| const int* g2c_dev_ids, |
| const uint32_t provided_grad_req_list_len, |
| const char** provided_grad_req_names, |
| const char** provided_grad_req_types, |
| const uint32_t num_provided_arg_shapes, |
| const char** provided_arg_shape_names, |
| const int* provided_arg_shape_data, |
| const uint32_t* provided_arg_shape_idx, |
| const uint32_t num_provided_arg_dtypes, |
| const char** provided_arg_dtype_names, |
| const int* provided_arg_dtypes, |
| const uint32_t num_provided_arg_stypes, |
| const char** provided_arg_stype_names, |
| const int* provided_arg_stypes, |
| const uint32_t num_shared_arg_names, |
| const char** shared_arg_name_list, |
| int* shared_buffer_len, |
| const char** shared_buffer_name_list, |
| NDArrayHandle* shared_buffer_handle_list, |
| const char*** updated_shared_buffer_name_list, |
| NDArrayHandle** updated_shared_buffer_handle_list, |
| uint32_t* num_in_args, |
| NDArrayHandle** in_args, |
| NDArrayHandle** arg_grads, |
| uint32_t* num_aux_states, |
| NDArrayHandle** aux_states, |
| ExecutorHandle shared_exec_handle, |
| ExecutorHandle* out); |
| |
| /*! |
| * \brief Simple-bind entry point whose provided_arg_shape_data is a |
| * `const int64_t*` array. MXExecutorSimpleBindEx takes `const int*` |
| * for the same parameter. |
| * NOTE(review): the remaining parameters mirror MXExecutorSimpleBindEx; |
| * their semantics are not restated here. |
| * \param provided_arg_shape_data shape values for the provided arguments, as int64_t |
| * \param out NOTE(review): presumably receives the created executor handle; confirm |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure, |
| * like other calls in this header; confirm) |
| */ |
| MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, |
| int dev_type, |
| int dev_id, |
| const uint32_t num_g2c_keys, |
| const char** g2c_keys, |
| const int* g2c_dev_types, |
| const int* g2c_dev_ids, |
| const uint32_t provided_grad_req_list_len, |
| const char** provided_grad_req_names, |
| const char** provided_grad_req_types, |
| const uint32_t num_provided_arg_shapes, |
| const char** provided_arg_shape_names, |
| const int64_t* provided_arg_shape_data, |
| const uint32_t* provided_arg_shape_idx, |
| const uint32_t num_provided_arg_dtypes, |
| const char** provided_arg_dtype_names, |
| const int* provided_arg_dtypes, |
| const uint32_t num_provided_arg_stypes, |
| const char** provided_arg_stype_names, |
| const int* provided_arg_stypes, |
| const uint32_t num_shared_arg_names, |
| const char** shared_arg_name_list, |
| int* shared_buffer_len, |
| const char** shared_buffer_name_list, |
| NDArrayHandle* shared_buffer_handle_list, |
| const char*** updated_shared_buffer_name_list, |
| NDArrayHandle** updated_shared_buffer_handle_list, |
| uint32_t* num_in_args, |
| NDArrayHandle** in_args, |
| NDArrayHandle** arg_grads, |
| uint32_t* num_aux_states, |
| NDArrayHandle** aux_states, |
| ExecutorHandle shared_exec_handle, |
| ExecutorHandle* out); |
| |
| |
| /*! |
| * \brief DEPRECATED. Use MXExecutorReshapeEx instead. |
| * Return a new executor with the same symbol and shared memory, |
| * but different input/output shapes. |
| * |
| * \param partial_shaping Whether to allow changing the shape of unspecified arguments. |
| * \param allow_up_sizing Whether to allow allocating new ndarrays that are larger than the original. |
| * \param dev_type device type of default context |
| * \param dev_id device id of default context |
| * \param num_map_keys size of group2ctx map |
| * \param map_keys keys of group2ctx map |
| * \param map_dev_types device type of group2ctx map |
| * \param map_dev_ids device id of group2ctx map |
| * \param num_provided_arg_shapes number of provided argument shapes |
| * \param provided_arg_shape_names names of the arguments whose shapes are provided |
| * \param provided_arg_shape_data shape values, as uint32_t |
| * (NOTE(review): presumably all shapes concatenated; confirm) |
| * \param provided_arg_shape_idx NOTE(review): presumably offsets into |
| * provided_arg_shape_data delimiting each shape; confirm |
| * \param num_in_args length of in_args |
| * \param in_args in args array |
| * \param arg_grads arg grads handle array |
| * \param num_aux_states length of auxiliary states |
| * \param aux_states auxiliary states array |
| * \param shared_exec input executor handle for memory sharing |
| * \param out output executor handle |
| * \return int status code; the new executor is returned through \a out |
| * (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXExecutorReshape(int partial_shaping, |
| int allow_up_sizing, |
| int dev_type, |
| int dev_id, |
| uint32_t num_map_keys, |
| const char** map_keys, |
| const int* map_dev_types, |
| const int* map_dev_ids, |
| const uint32_t num_provided_arg_shapes, |
| const char** provided_arg_shape_names, |
| const uint32_t* provided_arg_shape_data, |
| const uint32_t* provided_arg_shape_idx, |
| uint32_t* num_in_args, |
| NDArrayHandle** in_args, |
| NDArrayHandle** arg_grads, |
| uint32_t* num_aux_states, |
| NDArrayHandle** aux_states, |
| ExecutorHandle shared_exec, |
| ExecutorHandle *out); |
| /*! |
| * \brief Return a new executor with the same symbol and shared memory, |
| * but different input/output shapes. |
| * |
| * \param partial_shaping Whether to allow changing the shape of unspecified arguments. |
| * \param allow_up_sizing Whether to allow allocating new ndarrays that are larger than the original. |
| * \param dev_type device type of default context |
| * \param dev_id device id of default context |
| * \param num_map_keys size of group2ctx map |
| * \param map_keys keys of group2ctx map |
| * \param map_dev_types device type of group2ctx map |
| * \param map_dev_ids device id of group2ctx map |
| * \param num_provided_arg_shapes number of provided argument shapes |
| * \param provided_arg_shape_names names of the arguments whose shapes are provided |
| * \param provided_arg_shape_data shape values, as (signed) int |
| * (NOTE(review): presumably all shapes concatenated; confirm) |
| * \param provided_arg_shape_idx NOTE(review): presumably offsets into |
| * provided_arg_shape_data delimiting each shape; confirm |
| * \param num_in_args length of in_args |
| * \param in_args in args array |
| * \param arg_grads arg grads handle array |
| * \param num_aux_states length of auxiliary states |
| * \param aux_states auxiliary states array |
| * \param shared_exec input executor handle for memory sharing |
| * \param out output executor handle |
| * \return int status code; the new executor is returned through \a out |
| * (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXExecutorReshapeEx(int partial_shaping, |
| int allow_up_sizing, |
| int dev_type, |
| int dev_id, |
| uint32_t num_map_keys, |
| const char** map_keys, |
| const int* map_dev_types, |
| const int* map_dev_ids, |
| const uint32_t num_provided_arg_shapes, |
| const char** provided_arg_shape_names, |
| const int* provided_arg_shape_data, |
| const uint32_t* provided_arg_shape_idx, |
| uint32_t* num_in_args, |
| NDArrayHandle** in_args, |
| NDArrayHandle** arg_grads, |
| uint32_t* num_aux_states, |
| NDArrayHandle** aux_states, |
| ExecutorHandle shared_exec, |
| ExecutorHandle *out); |
| |
| /*! |
| * \brief get optimized graph from graph executor |
| * \param handle the executor handle to query |
| * \param out receives the symbol handle of the optimized graph |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure, |
| * like other calls in this header; confirm) |
| */ |
| MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, |
| SymbolHandle *out); |
| /*! |
| * \brief set a callback to notify the completion of operation |
| * \param handle the executor handle |
| * \param callback the monitor callback to install |
| * \param callback_handle opaque pointer supplied together with the callback |
| * (NOTE(review): presumably handed back to \a callback when it is invoked; confirm) |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure, |
| * like other calls in this header; confirm) |
| */ |
| MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, |
| ExecutorMonitorCallback callback, |
| void* callback_handle); |
| |
| /*! |
| * \brief set a callback to notify the completion of operation |
| * \param handle the executor handle |
| * \param callback the monitor callback to install |
| * \param callback_handle opaque pointer supplied together with the callback |
| * (NOTE(review): presumably handed back to \a callback when it is invoked; confirm) |
| * \param monitor_all If true, monitor both input and output, otherwise monitor output only. |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure, |
| * like other calls in this header; confirm) |
| */ |
| MXNET_DLL int MXExecutorSetMonitorCallbackEX(ExecutorHandle handle, |
| ExecutorMonitorCallback callback, |
| void *callback_handle, bool monitor_all); |
| //-------------------------------------------- |
| // Part 5: IO Interface |
| //-------------------------------------------- |
| /*! |
| * \brief List all the available iterator entries |
| * \param out_size the size of returned iterators |
| * \param out_array the output iterator entries |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXListDataIters(uint32_t *out_size, |
| DataIterCreator **out_array); |
| /*! |
| * \brief Create an iterator from an iterator creator and initialize it |
| * with the given parameters |
| * \param handle the iterator creator |
| * \param num_param number of parameters (the array size of the passed-in arguments) |
| * \param keys parameter keys |
| * \param vals parameter values |
| * \param out resulting iterator |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, |
| uint32_t num_param, |
| const char **keys, |
| const char **vals, |
| DataIterHandle *out); |
| /*! |
| * \brief Get the detailed information about data iterator. |
| * \param creator the DataIterCreator. |
| * \param name The returned name of the creator. |
| * \param description The returned description of the data iterator. |
| * \param num_args Number of arguments. |
| * \param arg_names Names of the arguments. |
| * \param arg_type_infos Type information about the arguments. |
| * \param arg_descriptions Description information about the arguments. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator, |
| const char **name, |
| const char **description, |
| uint32_t *num_args, |
| const char ***arg_names, |
| const char ***arg_type_infos, |
| const char ***arg_descriptions); |
| /*! |
| * \brief Free a data iterator handle of the IO module |
| * \param handle the handle pointer to the data iterator to free |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterFree(DataIterHandle handle); |
| /*! |
| * \brief Move iterator to next position |
| * \param handle the handle to iterator |
| * \param out receives the return value of next |
| * (NOTE(review): presumably non-zero while more data is available; confirm) |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterNext(DataIterHandle handle, |
| int *out); |
| /*! |
| * \brief Call iterator.Reset on the given iterator |
| * \param handle the handle to iterator |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); |
| |
| /*! |
| * \brief Get the handle to the NDArray of underlying data |
| * \param handle the handle pointer to the data iterator |
| * \param out receives the handle to the underlying data NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterGetData(DataIterHandle handle, |
| NDArrayHandle *out); |
| /*! |
| * \brief Get the image index by array. |
| * \param handle the handle pointer to the data iterator |
| * \param out_index receives a pointer to the output index array (uint64_t entries). |
| * \param out_size receives the output size of the array. |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, |
| uint64_t **out_index, |
| uint64_t *out_size); |
| /*! |
| * \brief Get the padding number in current data batch |
| * \param handle the handle pointer to the data iterator |
| * \param pad pointer that receives the pad number |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle, |
| int *pad); |
| |
| /*! |
| * \brief Get the handle to the NDArray of underlying label |
| * \param handle the handle pointer to the data iterator |
| * \param out receives the handle to the underlying label NDArray |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, |
| NDArrayHandle *out); |
| //-------------------------------------------- |
| // Part 6: basic KVStore interface |
| //-------------------------------------------- |
| /*! |
| * \brief Initialize ps-lite environment variables |
| * \param num_vars number of variables to initialize |
| * \param keys environment keys |
| * \param vals environment values |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure, |
| * like other calls in this header; confirm) |
| */ |
| MXNET_DLL int MXInitPSEnv(uint32_t num_vars, |
| const char **keys, |
| const char **vals); |
| |
| |
| /*! |
| * \brief Create a kvstore |
| * \param type the type of KVStore |
| * \param out receives the handle of the created KVStore |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreCreate(const char *type, |
| KVStoreHandle *out); |
| |
| /*! |
| * \brief Set parameters to use low-bit compressed gradients |
| * \param handle handle to the kvstore |
| * \param num_params number of compression parameters |
| * (NOTE(review): presumably the length of \a keys and \a vals; confirm) |
| * \param keys keys for compression parameters |
| * \param vals values for compression parameters |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle, |
| uint32_t num_params, |
| const char** keys, |
| const char** vals); |
| |
| /*! |
| * \brief Delete a KVStore handle. |
| * \param handle handle to the kvstore to delete |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreFree(KVStoreHandle handle); |
| /*! |
| * \brief Init a list of (key,value) pairs in kvstore, where each key is an integer |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of integer keys |
| * \param vals the list of values (NDArray handles) |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, |
| uint32_t num, |
| const int* keys, |
| NDArrayHandle* vals); |
| |
| /*! |
| * \brief Init a list of (key,value) pairs in kvstore, where each key is a string |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of string keys |
| * \param vals the list of values (NDArray handles) |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle, |
| uint32_t num, |
| const char** keys, |
| NDArrayHandle* vals); |
| |
| /*! |
| * \brief Push a list of (key,value) pairs to kvstore, where each key is an integer |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of integer keys |
| * \param vals the list of values (NDArray handles) |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePush(KVStoreHandle handle, |
| uint32_t num, |
| const int* keys, |
| NDArrayHandle* vals, |
| int priority); |
| /*! |
| * \brief Push a list of (key,value) pairs to kvstore, where each key is a string |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of string keys |
| * \param vals the list of values (NDArray handles) |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, |
| uint32_t num, |
| const char** keys, |
| NDArrayHandle* vals, |
| int priority); |
| /*! |
| * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of integer keys |
| * \param vals the list of values (NDArray handles) |
| * \param priority the priority of the action |
| * \param ignore_sparse whether to ignore sparse arrays in the request |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePullWithSparse(KVStoreHandle handle, |
| uint32_t num, |
| const int* keys, |
| NDArrayHandle* vals, |
| int priority, |
| bool ignore_sparse); |
| /*! |
| * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of string keys |
| * \param vals the list of values (NDArray handles) |
| * \param priority the priority of the action |
| * \param ignore_sparse whether to ignore sparse arrays in the request |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePullWithSparseEx(KVStoreHandle handle, |
| uint32_t num, |
| const char** keys, |
| NDArrayHandle* vals, |
| int priority, |
| bool ignore_sparse); |
| /*! |
| * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of integer keys |
| * \param vals the list of values (NDArray handles) |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePull(KVStoreHandle handle, |
| uint32_t num, |
| const int* keys, |
| NDArrayHandle* vals, |
| int priority); |
| /*! |
| * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of string keys |
| * \param vals the list of values (NDArray handles) |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, |
| uint32_t num, |
| const char** keys, |
| NDArrayHandle* vals, |
| int priority); |
| |
| /*! |
| * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer. |
| * The NDArray pulled back will be in row_sparse storage with only the rows |
| * specified by row_ids present (other rows are zeros). |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of integer keys |
| * \param vals the list of values |
| * \param row_ids the list of row_id NDArrays |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle, |
| uint32_t num, |
| const int* keys, |
| NDArrayHandle* vals, |
| const NDArrayHandle* row_ids, |
| int priority); |
| /*! |
| * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string. |
| * The NDArray pulled back will be in row_sparse storage with only the rows |
| * specified by row_ids present (other rows are zeros). |
| * \param handle handle to the kvstore |
| * \param num the number of key-value pairs |
| * \param keys the list of string keys |
| * \param vals the list of values |
| * \param row_ids the list of row_id NDArrays |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle, |
| uint32_t num, |
| const char** keys, |
| NDArrayHandle* vals, |
| const NDArrayHandle* row_ids, |
| int priority); |
| |
| /*! |
| * \brief broadcast a list of (key, value) pairs from the kvstore, |
| * where each key is an integer |
| * \param handle handle to the kvstore |
| * \param vnum the number of key-value pairs corresponding to vkeys |
| * \param vkeys the list of integer keys for the values to be pushed |
| * \param onum the number of key-value pairs corresponding to okeys |
| * \param okeys the list of integer keys for the values to be pulled |
| * \param vals the list of values |
| * \param outs the list of outputs |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreBroadcast(KVStoreHandle handle, |
| mx_uint vnum, |
| const int* vkeys, |
| mx_uint onum, |
| const int* okeys, |
| NDArrayHandle* vals, |
| NDArrayHandle* outs, |
| int priority); |
| /*! |
| * \brief broadcast a list of (key, value) pairs from the kvstore, |
| * where each key is a string |
| * \param handle handle to the kvstore |
| * \param vnum the number of key-value pairs corresponding to vkeys |
| * \param vkeys the list of string keys for the values to be pushed |
| * \param onum the number of key-value pairs corresponding to okeys |
| * \param okeys the list of string keys for the values to be pulled |
| * \param vals the list of values |
| * \param outs the list of outputs |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreBroadcastEx(KVStoreHandle handle, |
| mx_uint vnum, |
| const char** vkeys, |
| mx_uint onum, |
| const char** okeys, |
| NDArrayHandle* vals, |
| NDArrayHandle* outs, |
| int priority); |
| |
| /*! |
| * \brief push and pull a list of (key, value) pairs from the kvstore, |
| * where each key is an integer |
| * \param handle handle to the kvstore |
| * \param vnum the number of key-value pairs corresponding to vkeys |
| * \param vkeys the list of integer keys for the values to be pushed |
| * \param onum the number of key-value pairs corresponding to okeys |
| * \param okeys the list of integer keys for the values to be pulled |
| * \param vals the list of values |
| * \param outs the list of outputs |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle, |
| mx_uint vnum, |
| const int* vkeys, |
| mx_uint onum, |
| const int* okeys, |
| NDArrayHandle* vals, |
| NDArrayHandle* outs, |
| int priority); |
| /*! |
| * \brief push and pull a list of (key, value) pairs from the kvstore, |
| * where each key is a string |
| * \param handle handle to the kvstore |
| * \param vnum the number of key-value pairs corresponding to vkeys |
| * \param vkeys the list of string keys for the values to be pushed |
| * \param onum the number of key-value pairs corresponding to okeys |
| * \param okeys the list of string keys for the values to be pulled |
| * \param vals the list of values |
| * \param outs the list of outputs |
| * \param priority the priority of the action |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle, |
| mx_uint vnum, |
| const char** vkeys, |
| mx_uint onum, |
| const char** okeys, |
| NDArrayHandle* vals, |
| NDArrayHandle* outs, |
| int priority); |
| |
| /*! |
| * \brief user-defined updater for the kvstore (integer keys) |
| * It's this updater's responsibility to delete \a recv and \a local |
| * \param key the key |
| * \param recv the pushed value on this key |
| * \param local the value stored on local on this key |
| * \param handle The additional handle to the updater |
| */ |
| typedef void (MXKVStoreUpdater)(int key, |
| NDArrayHandle recv, |
| NDArrayHandle local, |
| void *handle); |
| /*! |
| * \brief user-defined updater for the kvstore with string keys |
| * It's this updater's responsibility to delete \a recv and \a local |
| * \param key the key (a C string) |
| * \param recv the pushed value on this key |
| * \param local the value stored on local on this key |
| * \param handle The additional handle to the updater |
| */ |
| typedef void (MXKVStoreStrUpdater)(const char* key, |
| NDArrayHandle recv, |
| NDArrayHandle local, |
| void *handle); |
| /*! |
| * \brief register a push updater |
| * \param handle handle to the KVStore |
| * \param updater updater function (integer keys) |
| * \param updater_handle The additional handle used to invoke the updater |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle, |
| MXKVStoreUpdater updater, |
| void *updater_handle); |
| /*! |
| * \brief register two push updaters: one with int keys and one with string keys |
| * \param handle handle to the KVStore |
| * \param updater updater function with int keys |
| * \param str_updater updater function with string keys |
| * \param updater_handle The additional handle used to invoke the updaters |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle, |
| MXKVStoreUpdater updater, |
| MXKVStoreStrUpdater str_updater, |
| void *updater_handle); |
| /*! |
| * \brief get the type of the kvstore |
| * \param handle handle to the KVStore |
| * \param type receives a pointer to the type string |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle, |
| const char** type); |
| //-------------------------------------------- |
| // Part 6: advanced KVStore for multi-machines |
| //-------------------------------------------- |
| |
| /** |
| * \brief return the rank of this node in its group, which is in [0, GroupSize). |
| * |
| * \param handle handle to the KVStore |
| * \param ret receives the node rank |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreGetRank(KVStoreHandle handle, |
| int *ret); |
| |
| /** |
| * \brief return the number of nodes in this group, which is |
| * - number of workers if `IsWorkerNode() == true`, |
| * - number of servers if `IsServerNode() == true`, |
| * - 1 if `IsSchedulerNode() == true`. |
| * \param handle handle to the KVStore |
| * \param ret receives the group size |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreGetGroupSize(KVStoreHandle handle, |
| int *ret); |
| |
| /** |
| * \brief return whether or not this process is a worker node. |
| * \param ret receives 1 for yes, 0 for no |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreIsWorkerNode(int *ret); |
| |
| |
| /** |
| * \brief return whether or not this process is a server node. |
| * \param ret receives 1 for yes, 0 for no |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreIsServerNode(int *ret); |
| |
| |
| /** |
| * \brief return whether or not this process is a scheduler node. |
| * \param ret receives 1 for yes, 0 for no |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreIsSchedulerNode(int *ret); |
| |
| /** |
| * \brief perform a global barrier among all worker machines |
| * |
| * \param handle handle to the KVStore |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle); |
| |
| /** |
| * \brief set whether to do a barrier when the kvstore finalizes |
| * |
| * \param handle handle to the KVStore |
| * \param barrier_before_exit whether to do barrier when kvstore finalize |
| * (an int flag; NOTE(review): presumably non-zero enables it; confirm) |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle, |
| const int barrier_before_exit); |
| |
| /** |
| * \brief the prototype of a server controller callback |
| * \param head the head of the command (an int) |
| * \param body the body of the command (a C string) |
| * \param controller_handle helper handle for implementing controller |
| */ |
| typedef void (MXKVStoreServerController)(int head, |
| const char *body, |
| void *controller_handle); |
| |
| /** |
| * \brief Run as server (or scheduler) |
| * \param handle handle to the KVStore |
| * \param controller the user-defined server controller (see MXKVStoreServerController) |
| * \param controller_handle helper handle for implementing controller |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle, |
| MXKVStoreServerController controller, |
| void *controller_handle); |
| |
| /** |
| * \brief Send a command to all server nodes |
| * \note The exported name is spelled "Commmand" (three m's); callers must |
| * use this exact spelling. |
| * \param handle handle to the KVStore |
| * \param cmd_id the head of the command |
| * \param cmd_body the body of the command |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXKVStoreSendCommmandToServers(KVStoreHandle handle, |
| int cmd_id, |
| const char* cmd_body); |
| |
| /** |
| * \brief Get the number of ps dead node(s) specified by {node_id} |
| * |
| * \param handle handle to the KVStore |
| * \param node_id Can be a node group or a single node. |
| * kScheduler = 1, kServerGroup = 2, kWorkerGroup = 4 |
| * \param number Output number of dead nodes |
| * \param timeout_sec A node that fails to send a heartbeat within {timeout_sec} seconds |
| * will be presumed 'dead'. Via the DEFAULT macro this defaults |
| * to 60 when compiled as C++; C callers must pass it explicitly. |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure, |
| * like other calls in this header; confirm) |
| */ |
| MXNET_DLL int MXKVStoreGetNumDeadNode(KVStoreHandle handle, |
| const int node_id, |
| int *number, |
| const int timeout_sec DEFAULT(60)); |
| |
| /** |
| * \brief Create a RecordIO writer object |
| * \param uri path to file |
| * \param out pointer that receives the handle of the created object |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOWriterCreate(const char *uri, RecordIOHandle *out); |
| |
| /** |
| * \brief Delete a RecordIO writer object |
| * \param handle handle to the RecordIO writer object |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOWriterFree(RecordIOHandle handle); |
| |
| /** |
| * \brief Write a record to a RecordIO object |
| * \param handle handle to RecordIO object |
| * \param buf buffer to write |
| * \param size size of buffer (a size_t) |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOWriterWriteRecord(RecordIOHandle handle, |
| const char *buf, size_t size); |
| |
| /** |
| * \brief Get the current writer pointer position |
| * \param handle handle to RecordIO object |
| * \param pos pointer that receives the output position |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos); |
| |
| /** |
| * \brief Create a RecordIO reader object |
| * \param uri path to file |
| * \param out pointer that receives the handle of the created object |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOReaderCreate(const char *uri, RecordIOHandle *out); |
| |
| /** |
| * \brief Delete a RecordIO reader object |
| * \param handle handle to the RecordIO reader object |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOReaderFree(RecordIOHandle handle); |
| |
| /** |
| * \brief Read a record from a RecordIO object |
| * \param handle handle to RecordIO object |
| * \param buf pointer to return buffer |
| * \param size pointer to size of buffer |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOReaderReadRecord(RecordIOHandle handle, |
| char const **buf, size_t *size); |
| |
| /** |
| * \brief Set the current reader pointer position |
| * \param handle handle to RecordIO object |
| * \param pos target position (a size_t) |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos); |
| |
| /** |
| * \brief Get the current reader pointer position |
| * \param handle handle to RecordIO object |
| * \param pos pointer that receives the output position |
| * \return 0 when success, -1 when failure happens |
| */ |
| MXNET_DLL int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos); |
| |
| /** |
| * \brief Create a MXRtc object |
| * NOTE(review): the parameter notes below are inferred from the parameter |
| * names and types only; confirm against the implementation. |
| * \param name name for the object |
| * \param num_input number of inputs |
| * \param num_output number of outputs |
| * \param input_names names of the inputs |
| * \param output_names names of the outputs |
| * \param inputs input NDArray handles |
| * \param outputs output NDArray handles |
| * \param kernel kernel text |
| * \param out receives the created RtcHandle |
| * \return int status code (presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcCreate(char* name, uint32_t num_input, uint32_t num_output, |
| char** input_names, char** output_names, |
| NDArrayHandle* inputs, NDArrayHandle* outputs, |
| char* kernel, RtcHandle *out); |
| |
| /** |
| * \brief Run cuda kernel |
| * NOTE(review): the parameter notes below are inferred from the parameter |
| * names and types only; confirm against the implementation. |
| * \param handle handle to the MXRtc object |
| * \param num_input number of inputs |
| * \param num_output number of outputs |
| * \param inputs input NDArray handles |
| * \param outputs output NDArray handles |
| * \param gridDimX grid dimension x |
| * \param gridDimY grid dimension y |
| * \param gridDimZ grid dimension z |
| * \param blockDimX block dimension x |
| * \param blockDimY block dimension y |
| * \param blockDimZ block dimension z |
| * \return int status code (presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcPush(RtcHandle handle, uint32_t num_input, uint32_t num_output, |
| NDArrayHandle* inputs, NDArrayHandle* outputs, |
| uint32_t gridDimX, |
| uint32_t gridDimY, |
| uint32_t gridDimZ, |
| uint32_t blockDimX, |
| uint32_t blockDimY, |
| uint32_t blockDimZ); |
| |
| /** |
| * \brief Delete a MXRtc object |
| * \param handle handle to the MXRtc object to delete |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcFree(RtcHandle handle); |
| /*! |
| * \brief register custom operators from frontend. |
| * \param op_type name of custom op |
| * \param creator the CustomOpPropCreator callback registered for \a op_type |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator); |
| /*! |
| * \brief record custom function for backward later. |
| * \param num_inputs number of input NDArrays. |
| * \param inputs handle to input NDArrays. |
| * \param num_outputs number of output NDArrays. |
| * \param outputs handle to output NDArrays. |
| * \param callbacks callbacks for backward function. |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs, |
| int num_outputs, NDArrayHandle *outputs, |
| struct MXCallbackList *callbacks); |
| /*! |
| * \brief create cuda rtc module |
| * \param source cuda source code |
| * \param num_options number of compiler flags |
| * \param options compiler flags |
| * \param num_exports number of exported function names |
| * \param exports exported function names |
| * \param out handle to created module |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcCudaModuleCreate(const char* source, int num_options, |
| const char** options, int num_exports, |
| const char** exports, CudaModuleHandle *out); |
| /*! |
| * \brief delete cuda rtc module |
| * \param handle handle to cuda module |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle); |
| /*! |
| * \brief get kernel from module |
| * \param handle handle to cuda module |
| * \param name name of kernel function |
| * \param num_args number of arguments |
| * \param is_ndarray whether argument is ndarray (int array) |
| * \param is_const whether argument is constant (int array) |
| * \param arg_types data type of arguments (int array) |
| * \param out created kernel |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, |
| int num_args, int* is_ndarray, int* is_const, |
| int* arg_types, CudaKernelHandle *out); |
| /*! |
| * \brief delete kernel |
| * \param handle handle to previously created kernel |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcCudaKernelFree(CudaKernelHandle handle); |
| /*! |
| * \brief launch cuda kernel |
| * \param handle handle to kernel |
| * \param dev_id (GPU) device id |
| * \param args pointer to arguments |
| * \param grid_dim_x grid dimension x |
| * \param grid_dim_y grid dimension y |
| * \param grid_dim_z grid dimension z |
| * \param block_dim_x block dimension x |
| * \param block_dim_y block dimension y |
| * \param block_dim_z block dimension z |
| * \param shared_mem size of dynamically allocated shared memory |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args, |
| uint32_t grid_dim_x, uint32_t grid_dim_y, |
| uint32_t grid_dim_z, uint32_t block_dim_x, |
| uint32_t block_dim_y, uint32_t block_dim_z, |
| uint32_t shared_mem); |
| /*! |
| * \brief Get shared memory handle from NDArray |
| * \param handle NDArray handle. |
| * \param shared_pid output PID |
| * \param shared_id output shared memory id. |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, |
| int* shared_id); |
| /*! |
| * \brief DEPRECATED. Use MXNDArrayCreateFromSharedMemEx instead. |
| * Reconstruct NDArray from shared memory handle |
| * \param shared_pid shared PID |
| * \param shared_id shared memory id |
| * \param shape pointer to NDArray dimensions (uint32_t entries) |
| * \param ndim number of NDArray dimensions |
| * \param dtype data type of NDArray |
| * \param out constructed NDArray |
| * \return int status code (NOTE(review): presumably 0 on success, -1 on failure; confirm) |
| */ |
| MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const uint32_t *shape, |
| uint32_t ndim, int dtype, NDArrayHandle *out); |
| |
| /*! |
| * \brief Release all unreferenced memory from the device's storage manager's memory pool. |
| * \param dev_type device type of the device whose memory pool is released |
| * \param dev_id the device id of the specific device |
| */ |
| MXNET_DLL int MXStorageEmptyCache(int dev_type, int dev_id); |
| |
| /*! |
| * \brief Reconstruct an NDArray from a shared memory handle. |
| * \param shared_pid shared PID |
| * \param shared_id shared memory id |
| * \param shape pointer to NDArray dimensions |
| * \param ndim number of NDArray dimensions |
| * \param dtype data type of NDArray |
| * \param out constructed NDArray |
| * \note Unlike the deprecated MXNDArrayCreateFromSharedMem, shape and ndim |
| * are passed as int here instead of uint32_t. |
| */ |
| MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, const int *shape, |
| int ndim, int dtype, NDArrayHandle *out); |
| |
| /*! |
| * \brief Push an asynchronous operation to the engine. |
| * \param async_func Execution function which takes a parameter on_complete |
| * that must be called when the execution completes. |
| * \param func_param The parameter set on calling async_func, can be NULL. |
| * \param deleter The callback to free func_param, can be NULL. |
| * \param ctx_handle Execution context. |
| * \param const_vars_handle The variables that the current operation will use |
| * but not mutate. |
| * \param num_const_vars The number of variables in const_vars_handle. |
| * \param mutable_vars_handle The variables that the current operation will mutate. |
| * \param num_mutable_vars The number of variables in mutable_vars_handle. |
| * \param prop_handle Property of the function. |
| * \param priority Priority of the action, as a hint to the engine. |
| * \param opr_name The operation name. |
| * \param wait Whether this is a WaitForVar operation. |
| * \note The DEFAULT(...) values (prop_handle NULL, priority 0, opr_name NULL, |
| * wait false) apply only when this header is compiled as C++; in C the |
| * macro expands to nothing and every argument must be passed explicitly. |
| */ |
| MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, |
| EngineFuncParamDeleter deleter, ContextHandle ctx_handle, |
| EngineVarHandle const_vars_handle, int num_const_vars, |
| EngineVarHandle mutable_vars_handle, int num_mutable_vars, |
| EngineFnPropertyHandle prop_handle DEFAULT(NULL), |
| int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), |
| bool wait DEFAULT(false)); |
| |
| /*! |
| * \brief Push a synchronous operation to the engine. |
| * \param sync_func Execution function that executes the operation. |
| * \param func_param The parameter set on calling sync_func, can be NULL. |
| * \param deleter The callback to free func_param, can be NULL. |
| * \param ctx_handle Execution context. |
| * \param const_vars_handle The variables that the current operation will use |
| * but not mutate. |
| * \param num_const_vars The number of variables in const_vars_handle. |
| * \param mutable_vars_handle The variables that the current operation will mutate. |
| * \param num_mutable_vars The number of variables in mutable_vars_handle. |
| * \param prop_handle Property of the function. |
| * \param priority Priority of the action, as a hint to the engine. |
| * \param opr_name The operation name. |
| * \note The DEFAULT(...) values (prop_handle NULL, priority 0, opr_name NULL) |
| * apply only when this header is compiled as C++; in C the macro expands |
| * to nothing and every argument must be passed explicitly. |
| */ |
| MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, |
| EngineFuncParamDeleter deleter, ContextHandle ctx_handle, |
| EngineVarHandle const_vars_handle, int num_const_vars, |
| EngineVarHandle mutable_vars_handle, int num_mutable_vars, |
| EngineFnPropertyHandle prop_handle DEFAULT(NULL), |
| int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); |
| /*! |
| * \brief Create an NDArray from source sharing the same data chunk. |
| * \param src source NDArray |
| * \param out new NDArray sharing the same data chunk with src |
| */ |
| MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out); |
| /*! |
| * \brief Create a Symbol from source sharing the same graph structure. |
| * \param src source Symbol |
| * \param out new Symbol sharing the same graph structure with src |
| */ |
| MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out); |
| |
| /*! |
| * \brief Push an asynchronous operation to the engine, with the operation's |
| * dependencies given as NDArrays. |
| * \param async_func Execution function which takes a parameter on_complete |
| * that must be called when the execution completes. |
| * \param func_param The parameter set on calling async_func, can be NULL. |
| * \param deleter The callback to free func_param, can be NULL. |
| * \param ctx_handle Execution context. |
| * \param const_nds_handle The NDArrays that the current operation will use |
| * but not mutate. |
| * \param num_const_nds The number of NDArrays in const_nds_handle. |
| * \param mutable_nds_handle The NDArrays that the current operation will mutate. |
| * \param num_mutable_nds The number of NDArrays in mutable_nds_handle. |
| * \param prop_handle Property of the function. |
| * \param priority Priority of the action, as a hint to the engine. |
| * \param opr_name The operation name. |
| * \param wait Whether this is a WaitForVar operation. |
| * \note The DEFAULT(...) values (prop_handle NULL, priority 0, opr_name NULL, |
| * wait false) apply only when this header is compiled as C++; in C the |
| * macro expands to nothing and every argument must be passed explicitly. |
| */ |
| MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, |
| EngineFuncParamDeleter deleter, ContextHandle ctx_handle, |
| NDArrayHandle* const_nds_handle, int num_const_nds, |
| NDArrayHandle* mutable_nds_handle, int num_mutable_nds, |
| EngineFnPropertyHandle prop_handle DEFAULT(NULL), |
| int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), |
| bool wait DEFAULT(false)); |
| |
| /*! |
| * \brief Push a synchronous operation to the engine, with the operation's |
| * dependencies given as NDArrays. |
| * \param sync_func Execution function that executes the operation. |
| * \param func_param The parameter set on calling sync_func, can be NULL. |
| * \param deleter The callback to free func_param, can be NULL. |
| * \param ctx_handle Execution context. |
| * \param const_nds_handle The NDArrays that the current operation will use |
| * but not mutate. |
| * \param num_const_nds The number of NDArrays in const_nds_handle. |
| * \param mutable_nds_handle The NDArrays that the current operation will mutate. |
| * \param num_mutable_nds The number of NDArrays in mutable_nds_handle. |
| * \param prop_handle Property of the function. |
| * \param priority Priority of the action, as a hint to the engine. |
| * \param opr_name The operation name. |
| * \note The DEFAULT(...) values (prop_handle NULL, priority 0, opr_name NULL) |
| * apply only when this header is compiled as C++; in C the macro expands |
| * to nothing and every argument must be passed explicitly. |
| */ |
| MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, |
| EngineFuncParamDeleter deleter, ContextHandle ctx_handle, |
| NDArrayHandle* const_nds_handle, int num_const_nds, |
| NDArrayHandle* mutable_nds_handle, int num_mutable_nds, |
| EngineFnPropertyHandle prop_handle DEFAULT(NULL), |
| int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); |
| |
| #ifdef __cplusplus |
| } |
| #endif // __cplusplus |
| |
| #endif // MXNET_C_API_H_ |