blob: b3032351fe7a32856b3eb56cdd955113d1075eb6 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file io.h
* \brief Rcpp Data Loading and Iteration Interface of MXNet.
*/
#ifndef MXNET_RCPP_IO_H_
#define MXNET_RCPP_IO_H_
#include <Rcpp.h>
#include <mxnet/c_api.h>
#include <string>
#include <vector>
#include "./base.h"
#include "./ndarray.h"
namespace mxnet {
namespace R {
// creator function of DataIter
class DataIterCreateFunction;
/*! \brief Base iterator interface */
class DataIter {
 public:
  virtual ~DataIter() {}
  /*! \brief Reset the iterator. */
  virtual void Reset() = 0;
  /*!
   * \brief Advance the iterator to its next position.
   * \return true if the move succeeded, false otherwise.
   */
  virtual bool Next() = 0;
  /*!
   * \brief Query padding at the current position.
   * \return the number of padding examples.
   */
  virtual int NumPad() const = 0;
  /*!
   * \brief Fetch the element at the current position.
   * \return an Rcpp::List of NDArray holding the elements of this value.
   */
  virtual Rcpp::List Value() const = 0;

  /*! \return the class name used on the R side. */
  static const char* TypeName() { return "DataIter"; }
  /*! \brief Register this class with the R cpp module. */
  static void InitRcppModule();
};
/*!
* \brief MXNet's internal data iterator.
*/
class MXDataIter : public DataIter {
public:
/*! \return typename from R side. */
inline static const char* TypeName() {
return "MXNativeDataIter";
}
// implement the interface
virtual void Reset();
virtual bool Next();
virtual int NumPad() const;
virtual Rcpp::List Value() const;
virtual ~MXDataIter() {
MX_CALL(MXDataIterFree(handle_));
}
private:
friend class DataIter;
friend class DataIterCreateFunction;
// constructor
MXDataIter() {}
explicit MXDataIter(DataIterHandle handle)
: handle_(handle) {}
/*!
* \brief create a R object that correspond to the Class
* \param handle the Handle needed for output.
*/
inline static Rcpp::RObject RObject(DataIterHandle handle) {
return Rcpp::internal::make_new_object(new MXDataIter(handle));
}
/*! \brief internal data iter handle */
DataIterHandle handle_;
};
/*!
* \brief data iterator that takes a NumericVector
* Shuffles it and iterate over its content.
* TODO(KK, tq) implement this when have time.
* c.f. python/io.py:NDArrayIter
*/
class ArrayDataIter : public DataIter {
 public:
  /*! \return typename from R side. */
  inline static const char* TypeName() {
    return "MXArrayDataIter";
  }
  /*!
   * \brief Construct an ArrayDataIter from data and label.
   * \param data The data array.
   * \param label The label array.
   * \param unif_rnds Uniform [0,1] random numbers of the same length as label.
   *  Only needed when shuffle=TRUE.
   * \param batch_size The size of each batch.
   * \param shuffle Whether to shuffle the data.
   */
  ArrayDataIter(const Rcpp::NumericVector& data,
                const Rcpp::NumericVector& label,
                const Rcpp::NumericVector& unif_rnds,
                int batch_size,
                bool shuffle);
  /*! \brief Reset the iterator by setting the counter back to zero. */
  virtual void Reset() {
    counter_ = 0;
  }
  virtual bool Next();
  virtual int NumPad() const;
  virtual Rcpp::List Value() const;
  /*!
   * \brief Factory exposed to R; takes the same arguments as the constructor.
   * \return an R object for the created iterator (see the implementation).
   */
  static Rcpp::RObject Create(const Rcpp::NumericVector& data,
                              const Rcpp::NumericVector& label,
                              const Rcpp::NumericVector& unif_rnds,
                              int batch_size,
                              bool shuffle);

 private:
  friend class DataIter;
  /*!
   * \brief Create the internal per-batch representation of src.
   * \param src source values.
   * \param order order in which to read src (presumably the shuffled or
   *  identity permutation — TODO confirm against the implementation).
   * \param batch_size number of instances per batch.
   * \param out receives the resulting list of NDArray batches.
   */
  static void Convert(const Rcpp::NumericVector &src,
                      const std::vector<size_t> &order,
                      size_t batch_size,
                      std::vector<NDArray> *out);
  /*! \brief Iteration counter; set to 0 by Reset(). */
  size_t counter_;
  /*! \brief number of pad instances */
  size_t num_pad_;
  /*! \brief number of data instances */
  // NOTE(review): lacks the trailing underscore used by the other members;
  // not renamed here since the out-of-view implementation likely uses it.
  size_t num_data;
  /*! \brief The data list of each batch */
  std::vector<NDArray> data_;
  /*! \brief The label list of each batch */
  std::vector<NDArray> label_;
};
/*! \brief The DataIterCreate functions to be invoked */
class DataIterCreateFunction : public ::Rcpp::CppFunction {
 public:
  /*! \brief Invoke the wrapped creator with the R-side arguments. */
  virtual SEXP operator() (SEXP* args);
  /*! \return the number of R arguments accepted (a single list). */
  virtual int nargs() {
    return 1;
  }
  /*! \return false: this function returns a value to R. */
  virtual bool is_void() {
    return false;
  }
  /*! \brief Write the signature "SEXP name(Rcpp::List)" into s. */
  virtual void signature(std::string& s, const char* name) {  // NOLINT(*)
    ::Rcpp::signature< SEXP, ::Rcpp::List >(s, name);
  }
  /*! \return the name of this function. */
  virtual const char* get_name() {
    return name_.c_str();
  }
  /*! \return the formal argument list exposed to R. */
  virtual SEXP get_formals() {
    return Rcpp::List::create(Rcpp::_["alist"]);
  }
  /*!
   * \return NULL: no plain C function pointer backs this object (calls
   *  presumably go through operator()).
   */
  virtual DL_FUNC get_function_ptr() {
    // Named cast instead of the previous C-style (DL_FUNC)NULL.
    return static_cast<DL_FUNC>(NULL);
  }
  /*! \brief static function to initialize the Rcpp functions */
  static void InitRcppModule();

 private:
  // make constructor private
  explicit DataIterCreateFunction(DataIterCreator handle);
  /*! \brief internal creator handle. */
  DataIterCreator handle_;
  /*! \brief name of the function */
  std::string name_;
};
} // namespace R
} // namespace mxnet
// Register the concrete iterator classes with Rcpp's as/wrap machinery so
// they can cross the R/C++ boundary as module objects. The _NODECL variant
// is used because these classes are already declared above.
RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::MXDataIter);
RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::ArrayDataIter);
namespace Rcpp {
// True when obj is an R module object wrapping an MXDataIter.
template<>
inline bool is<mxnet::R::MXDataIter>(SEXP obj) {
  return internal::is__module__object_fix<mxnet::R::MXDataIter>(obj);
}
// True when obj is an R module object wrapping an ArrayDataIter.
template<>
inline bool is<mxnet::R::ArrayDataIter>(SEXP obj) {
  return internal::is__module__object_fix<mxnet::R::ArrayDataIter>(obj);
}
// An R object counts as a DataIter when it wraps either concrete iterator
// class above. This patch needs to be kept even after the Rcpp update is
// merged in.
template<>
inline bool is<mxnet::R::DataIter>(SEXP obj) {
  if (is<mxnet::R::MXDataIter>(obj)) {
    return true;
  }
  return is<mxnet::R::ArrayDataIter>(obj);
}
}  // namespace Rcpp
#endif // MXNET_RCPP_IO_H_