| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file operator.hpp |
| * \brief implementation of data iter |
| * \author Zhang Chen |
| */ |
| #ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_HPP_ |
| #define CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_HPP_ |
| |
| #include <string> |
| #include <vector> |
| #include "mxnet-cpp/io.h" |
| |
| namespace mxnet { |
| namespace cpp { |
| |
| inline MXDataIterMap*& MXDataIter::mxdataiter_map() { |
| static MXDataIterMap* mxdataiter_map_ = new MXDataIterMap; |
| return mxdataiter_map_; |
| } |
| |
| inline MXDataIter::MXDataIter(const std::string &mxdataiter_type) { |
| creator_ = mxdataiter_map()->GetMXDataIterCreator(mxdataiter_type); |
| blob_ptr_ = std::make_shared<MXDataIterBlob>(nullptr); |
| } |
| |
| inline void MXDataIter::BeforeFirst() { |
| int r = MXDataIterBeforeFirst(blob_ptr_->handle_); |
| CHECK_EQ(r, 0); |
| } |
| |
| inline bool MXDataIter::Next() { |
| int out; |
| int r = MXDataIterNext(blob_ptr_->handle_, &out); |
| CHECK_EQ(r, 0); |
| return out; |
| } |
| |
| inline NDArray MXDataIter::GetData() { |
| NDArrayHandle handle; |
| int r = MXDataIterGetData(blob_ptr_->handle_, &handle); |
| CHECK_EQ(r, 0); |
| return NDArray(handle); |
| } |
| |
| inline NDArray MXDataIter::GetLabel() { |
| NDArrayHandle handle; |
| int r = MXDataIterGetLabel(blob_ptr_->handle_, &handle); |
| CHECK_EQ(r, 0); |
| return NDArray(handle); |
| } |
| |
| inline int MXDataIter::GetPadNum() { |
| int out; |
| int r = MXDataIterGetPadNum(blob_ptr_->handle_, &out); |
| CHECK_EQ(r, 0); |
| return out; |
| } |
| inline std::vector<int> MXDataIter::GetIndex() { |
| uint64_t *out_index, out_size; |
| int r = MXDataIterGetIndex(blob_ptr_->handle_, &out_index, &out_size); |
| CHECK_EQ(r, 0); |
| std::vector<int> ret; |
| for (uint64_t i = 0; i < out_size; ++i) { |
| ret.push_back(out_index[i]); |
| } |
| return ret; |
| } |
| |
| inline MXDataIter MXDataIter::CreateDataIter() { |
| std::vector<const char *> param_keys; |
| std::vector<const char *> param_values; |
| |
| for (auto &data : params_) { |
| param_keys.push_back(data.first.c_str()); |
| param_values.push_back(data.second.c_str()); |
| } |
| |
| MXDataIterCreateIter(creator_, param_keys.size(), param_keys.data(), |
| param_values.data(), &blob_ptr_->handle_); |
| return *this; |
| } |
| |
| // MXDataIter MNIst |
| |
| } // namespace cpp |
| } // namespace mxnet |
| |
| #endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_IO_HPP_ |
| |