blob: d8857f80635d202f86bf7b074e55f35325b0e052 [file] [log] [blame]
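/*!
 * Copyright (c) 2015 by Contributors
 * \file c_api_common.h
 * \brief Common helpers shared by the C API implementation: error-handling
 *  macros, a thread-local store for returned values, and graph attribute
 *  utilities.
 */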
#ifndef MXNET_C_API_C_API_COMMON_H_
#define MXNET_C_API_C_API_COMMON_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <mxnet/c_api.h>
#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <exception>
#include <string>
#include <vector>
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*!
 * \brief Closes the try-block opened by API_BEGIN().
 *
 * Every C API function starts with API_BEGIN() and finishes with API_END()
 * or API_END_HANDLE_ERROR(). A caught dmlc::Error is reported through
 * MXAPIHandleException(); any other std::exception has its what() message
 * recorded with MXAPISetLastError(). In both cases the function returns -1,
 * so no C++ exception propagates across the C ABI boundary. When nothing
 * is thrown the function returns 0.
 */
#define API_END() } catch(dmlc::Error &_except_) { return MXAPIHandleException(_except_); } catch(std::exception &_except_) { MXAPISetLastError(_except_.what()); return -1; } return 0; // NOLINT(*)
/*!
 * \brief Same as API_END(), but evaluates the Finalize statement before the
 *  error is reported, so partially-built state can be cleaned up.
 * \param Finalize cleanup code executed only when an exception was caught.
 *  It runs for dmlc::Error as well as for any other std::exception.
 */
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return MXAPIHandleException(_except_); } catch(std::exception &_except_) { Finalize; MXAPISetLastError(_except_.what()); return -1; } return 0; // NOLINT(*)
/*!
 * \brief Store an error message so it can later be reported by the C API.
 * \param msg The error message to set.
 */
void MXAPISetLastError(const char* msg);
/*!
 * \brief Translate a caught dmlc::Error into a C API failure code.
 *
 * The exception's what() text is handed to MXAPISetLastError().
 * \param e the exception that was thrown
 * \return always -1, the value a C API function returns on failure
 */
inline int MXAPIHandleException(const dmlc::Error &e) {
  const char *message = e.what();
  MXAPISetLastError(message);
  const int kFailure = -1;
  return kFailure;
}
using namespace mxnet;
/*!
 * \brief Scratch storage for values that C API functions hand back to callers.
 *
 * Accessed through the MXAPIThreadLocalStore typedef declared after this
 * struct. Results are parked in these members, presumably so that pointers
 * returned to C callers remain valid after the API call returns.
 */
struct MXAPIThreadLocalEntry {
  /*! \brief holds a single returned string */
  std::string ret_str;
  /*! \brief holds a list of returned strings */
  std::vector<std::string> ret_vec_str;
  /*! \brief holds returned C-string pointers */
  std::vector<const char *> ret_vec_charp;
  /*! \brief holds returned opaque handles */
  std::vector<void *> ret_handles;
  /*! \brief holds returned shapes for arguments, outputs and auxiliary states */
  std::vector<TShape> arg_shapes, out_shapes, aux_shapes;
  /*! \brief holds returned type flags for arguments, outputs and auxiliary states */
  std::vector<int> arg_types, out_types, aux_types;
  /*! \brief holds the number of dimensions of each returned shape */
  std::vector<mx_uint> arg_shape_ndim, out_shape_ndim, aux_shape_ndim;
  /*! \brief holds, per shape, a pointer to its dimensions inside the matching buffer */
  std::vector<const mx_uint*> arg_shape_data, out_shape_data, aux_shape_data;
  /*! \brief flat uint32_t storage backing the *_shape_data pointers */
  std::vector<uint32_t> arg_shape_buffer, out_shape_buffer, aux_shape_buffer;
  /*!
   * \brief Flatten a list of shapes into C-friendly (ndim, data) arrays.
   *
   * All dimensions of all shapes are copied back-to-back into *buffer.
   * Afterwards (*ndim)[i] is shapes[i].ndim() and (*data)[i] points at the
   * first dimension of shapes[i] inside *buffer. For a shape with zero
   * dimensions the pointer is not meant to be dereferenced.
   * The pointers are invalidated if *buffer is resized or destroyed.
   * Storing a uint32_t* into a const mx_uint* relies on mx_uint being uint32_t.
   *
   * \param shapes the shapes to export
   * \param ndim output: number of dimensions per shape
   * \param data output: pointer to each shape's dimensions within *buffer
   * \param buffer output: backing storage for every dimension value
   */
  inline static void SetupShapeArrayReturnWithBuffer(
      const std::vector<TShape> &shapes,
      std::vector<mx_uint> *ndim,
      std::vector<const mx_uint*> *data,
      std::vector<uint32_t> *buffer) {
    const size_t num_shapes = shapes.size();
    ndim->resize(num_shapes);
    data->resize(num_shapes);
    // First pass: total number of dimensions, to size the buffer once.
    size_t total_dims = 0;
    for (size_t k = 0; k < num_shapes; ++k) {
      total_dims += shapes[k].ndim();
    }
    buffer->resize(total_dims);
    // Second pass: copy dimensions and record where each shape begins.
    uint32_t *cursor = buffer->data();
    for (size_t k = 0; k < num_shapes; ++k) {
      const TShape &shape = shapes[k];
      (*ndim)[k] = shape.ndim();
      (*data)[k] = cursor;
      cursor = nnvm::ShapeTypeCast(shape.begin(), shape.end(), cursor);
    }
  }
};
/*! \brief Store giving each thread its own MXAPIThreadLocalEntry (via dmlc::ThreadLocalStore). */
using MXAPIThreadLocalStore = dmlc::ThreadLocalStore<MXAPIThreadLocalEntry>;
namespace mxnet {
/*!
 * \brief Split an attribute vector indexed by graph entry id into separate
 *  input, output and auxiliary vectors.
 *
 * For each input node, in idx.input_nodes() order, the attribute of its first
 * output entry (entry_id(nid, 0)) goes to *aux_attr if the node is in
 * idx.mutable_input_nodes(), and to *in_attr otherwise. For each graph
 * output, in idx.outputs() order, its attribute goes to *out_attr.
 * All three destination vectors are cleared first.
 *
 * \param idx indexed graph that defines the entry ids
 * \param attr_vec attributes indexed by entry id
 * \param in_attr output: attributes of non-mutable input nodes
 * \param out_attr output: attributes of graph outputs
 * \param aux_attr output: attributes of mutable input nodes
 */
template<typename AttrType>
inline void CopyAttr(const nnvm::IndexedGraph& idx,
                     const std::vector<AttrType>& attr_vec,
                     std::vector<AttrType>* in_attr,
                     std::vector<AttrType>* out_attr,
                     std::vector<AttrType>* aux_attr) {
  in_attr->clear();
  out_attr->clear();
  aux_attr->clear();
  const auto& mutable_nodes = idx.mutable_input_nodes();
  for (uint32_t node_id : idx.input_nodes()) {
    std::vector<AttrType>* target =
        mutable_nodes.count(node_id) != 0 ? aux_attr : in_attr;
    target->push_back(attr_vec[idx.entry_id(node_id, 0)]);
  }
  for (const auto& entry : idx.outputs()) {
    out_attr->push_back(attr_vec[idx.entry_id(entry)]);
  }
}
/*! \brief keys that are stored in hidden form (converted to __key__) */
extern const std::vector<std::string> kHiddenKeys;
}  // namespace mxnet
#endif // MXNET_C_API_C_API_COMMON_H_