// blob: 09799f2b199139aa1e6771e5e79d8a2985127b35 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file iter_mnist.cc
* \brief register mnist iterator
*/
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./iter_prefetcher.h"
#include "../common/utils.h"
namespace mxnet {
namespace io {
// Define mnist io parameters
/*!
 * \brief Parameters of the MNIST iterator.
 *
 *  The fields are declared through the dmlc parameter macros below, which
 *  also attach the default value, the bounds and the description of each one.
 */
struct MNISTParam : public dmlc::Parameter<MNISTParam> {
  /*! \brief path of the image file */
  std::string image;
  /*! \brief path of the label file */
  std::string label;
  /*! \brief whether to shuffle the instances once after loading */
  bool shuffle;
  /*! \brief when true, the summary line logged after loading is suppressed */
  bool silent;
  /*! \brief number of instances per batch */
  int batch_size;
  /*! \brief when true, each image is exposed as one flat row instead of rows x cols */
  bool flat;
  /*! \brief random seed used for shuffling */
  int seed;
  /*! \brief number of parts the data set is split into */
  int num_parts;
  /*! \brief index of the part this iterator reads, in [0, num_parts) */
  int part_index;
  // declare parameters
  DMLC_DECLARE_PARAMETER(MNISTParam) {
    DMLC_DECLARE_FIELD(image).set_default("./train-images-idx3-ubyte")
        .describe("Dataset Param: Mnist image path.");
    DMLC_DECLARE_FIELD(label).set_default("./train-labels-idx1-ubyte")
        .describe("Dataset Param: Mnist label path.");
    DMLC_DECLARE_FIELD(batch_size).set_lower_bound(1).set_default(128)
        .describe("Batch Param: Batch Size.");
    DMLC_DECLARE_FIELD(shuffle).set_default(true)
        .describe("Augmentation Param: Whether to shuffle data.");
    DMLC_DECLARE_FIELD(flat).set_default(false)
        .describe("Augmentation Param: Whether to flat the data into 1D.");
    DMLC_DECLARE_FIELD(seed).set_default(0)
        .describe("Augmentation Param: Random Seed.");
    // silent == true turns the log line OFF, so the description says "suppress".
    DMLC_DECLARE_FIELD(silent).set_default(false)
        .describe("Auxiliary Param: Whether to suppress printing of data info.");
    // The bounds mirror the checks done when the partition is computed, so
    // that invalid values are rejected while the arguments are parsed.
    DMLC_DECLARE_FIELD(num_parts).set_lower_bound(1).set_default(1)
        .describe("partition the data into multiple parts");
    DMLC_DECLARE_FIELD(part_index).set_lower_bound(0).set_default(0)
        .describe("the index of the part will read");
  }
};
/*!
 * \brief Iterator over an MNIST image file / label file pair.
 *
 *  Init() reads the selected partition of both files completely into memory
 *  (pixels are converted to float and scaled by 1/256) and optionally
 *  shuffles it once. Next() then exposes consecutive windows of
 *  batch_size instances without copying; instances at the end that do not
 *  fill a whole batch are never returned.
 */
class MNISTIter: public IIterator<TBlobBatch> {
 public:
  MNISTIter(void) : loc_(0), inst_offset_(0) {
    img_.dptr_ = nullptr;
    out_.data.resize(2);
  }
  virtual ~MNISTIter(void) {
    // img_ does not own its memory by itself; the buffer is allocated in LoadImage
    delete[] img_.dptr_;
  }
  /*!
   * \brief initialize the iterator: parse the arguments and load the data
   * \param kwargs key/value arguments; they are parsed with InitAllowUnknown,
   *        so keys that MNISTParam does not declare are tolerated
   */
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
    std::map<std::string, std::string> kmap(kwargs.begin(), kwargs.end());
    param_.InitAllowUnknown(kmap);
    this->LoadImage();
    this->LoadLabel();
    // Shuffle() and Next() index labels_ with image indices, so both files
    // have to describe the same number of instances.
    CHECK_EQ(static_cast<size_t>(img_.size(0)), labels_.size())
        << "MNISTIter: number of images and number of labels do not match";
    if (param_.flat) {
      batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, 1, img_.size(1) * img_.size(2));
    } else {
      batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, img_.size(1), img_.size(2));
    }
    out_.data.clear();
    batch_label_.shape_ = mshadow::Shape2(param_.batch_size, 1);
    batch_label_.stride_ = 1;
    batch_data_.stride_ = batch_data_.size(3);
    out_.batch_size = param_.batch_size;
    // start from the first instance, also when Init is called a second time
    loc_ = 0;
    if (param_.shuffle) this->Shuffle();
    if (!param_.silent) {
      mshadow::TShape s;
      s = batch_data_.shape_;
      if (param_.flat) {
        LOG(INFO) << "MNISTIter: load " << static_cast<unsigned>(img_.size(0))
                  << " images, shuffle=" << param_.shuffle
                  << ", shape=" << s.FlatTo2D();
      } else {
        LOG(INFO) << "MNISTIter: load " << static_cast<unsigned>(img_.size(0))
                  << " images, shuffle=" << param_.shuffle
                  << ", shape=" << s;
      }
    }
  }
  /*! \brief rewind to the first instance */
  virtual void BeforeFirst(void) {
    this->loc_ = 0;
  }
  /*!
   * \brief advance to the next batch
   * \return true if a full batch was prepared, false if fewer than
   *         batch_size instances are left
   */
  virtual bool Next(void) {
    if (loc_ + param_.batch_size > img_.size(0)) return false;
    // the batch tensors are views into the loaded data, nothing is copied
    batch_data_.dptr_ = img_[loc_].dptr_;
    batch_label_.dptr_ = &labels_[loc_];
    out_.data.clear();
    if (param_.flat) {
      out_.data.push_back(TBlob(batch_data_.FlatTo2D()));
    } else {
      out_.data.push_back(TBlob(batch_data_));
    }
    out_.data.push_back(TBlob(batch_label_));
    loc_ += param_.batch_size;
    return true;
  }
  /*! \brief the batch prepared by the last successful call to Next: data blob, then label blob */
  virtual const TBlobBatch &Value(void) const {
    return out_;
  }

 private:
  /*!
   * \brief compute the range of instances that belongs to the configured part
   * \param count total number of instances in the file
   * \param start output, index of the first instance of the part
   * \param end output, index one past the last instance of the part
   */
  inline void GetPart(int count, int* start, int *end) {
    CHECK_GE(param_.part_index, 0);
    CHECK_GT(param_.num_parts, 0);
    CHECK_GT(param_.num_parts, param_.part_index);
    *start = static_cast<int>(
        static_cast<double>(count) / param_.num_parts * param_.part_index);
    *end = static_cast<int>(
        static_cast<double>(count) / param_.num_parts * (param_.part_index+1));
  }
  /*! \brief load the selected part of the image file into img_, scaled by 1/256 */
  inline void LoadImage(void) {
    // release the pixels of an earlier Init so that re-initialization does not leak
    delete[] img_.dptr_;
    img_.dptr_ = nullptr;
    // owned by a smart pointer so the stream is closed even if a CHECK below throws
    std::unique_ptr<dmlc::SeekStream> stdimg(
        dmlc::SeekStream::CreateForRead(param_.image.c_str()));
    CHECK(stdimg.get() != nullptr)
        << "MNISTIter: cannot open image file " << param_.image;
    ReadInt(stdimg.get());  // first header word (magic number), not inspected
    int image_count = ReadInt(stdimg.get());
    const int image_rows = ReadInt(stdimg.get());
    const int image_cols = ReadInt(stdimg.get());
    CHECK(image_count >= 0 && image_rows > 0 && image_cols > 0)
        << "invalid mnist format";
    int start, end;
    GetPart(image_count, &start, &end);
    image_count = end - start;
    // one byte per pixel; computed in size_t so the seek offset cannot overflow int
    const size_t image_bytes = static_cast<size_t>(image_rows) * image_cols;
    if (start > 0) {
      stdimg->Seek(stdimg->Tell() + static_cast<size_t>(start) * image_bytes);
    }
    img_.shape_ = mshadow::Shape3(image_count, image_rows, image_cols);
    img_.stride_ = img_.size(2);
    // allocate continuous memory
    img_.dptr_ = new float[img_.MSize()];
    // read one whole image per call instead of one byte per call
    std::vector<unsigned char> pixels(image_bytes);
    for (int i = 0; i < image_count; ++i) {
      CHECK(stdimg->Read(pixels.data(), image_bytes) == image_bytes)
          << "invalid mnist format";
      const unsigned char *src = pixels.data();
      for (int j = 0; j < image_rows; ++j) {
        for (int k = 0; k < image_cols; ++k) {
          img_[i][j][k] = *src++;
        }
      }
    }
    // normalize to 0-1
    img_ *= 1.0f / 256.0f;
  }
  /*! \brief load the selected part of the label file into labels_ and rebuild inst_ */
  inline void LoadLabel(void) {
    std::unique_ptr<dmlc::SeekStream> stdlabel(
        dmlc::SeekStream::CreateForRead(param_.label.c_str()));
    CHECK(stdlabel.get() != nullptr)
        << "MNISTIter: cannot open label file " << param_.label;
    ReadInt(stdlabel.get());  // first header word (magic number), not inspected
    int labels_count = ReadInt(stdlabel.get());
    CHECK_GE(labels_count, 0) << "invalid mnist format";
    int start, end;
    GetPart(labels_count, &start, &end);
    labels_count = end - start;
    if (start > 0) {
      // one byte per label
      stdlabel->Seek(stdlabel->Tell() + start);
    }
    std::vector<unsigned char> raw(labels_count);
    if (!raw.empty()) {
      CHECK(stdlabel->Read(raw.data(), raw.size()) == raw.size())
          << "invalid mnist format";
    }
    labels_.assign(raw.begin(), raw.end());
    // rebuilt from scratch so that a second Init does not append duplicates
    inst_.clear();
    inst_.reserve(raw.size());
    for (size_t i = 0; i < raw.size(); ++i) {
      inst_.push_back(static_cast<unsigned>(i) + inst_offset_);
    }
  }
  /*! \brief permute images and labels in memory with the same random order */
  inline void Shuffle(void) {
    std::shuffle(inst_.begin(), inst_.end(), common::RANDOM_ENGINE(kRandMagic + param_.seed));
    std::vector<float> tmplabel(labels_.size());
    mshadow::TensorContainer<cpu, 3> tmpimg(img_.shape_);
    for (size_t i = 0; i < inst_.size(); ++i) {
      const unsigned ridx = inst_[i] - inst_offset_;
      mshadow::Copy(tmpimg[i], img_[ridx]);
      tmplabel[i] = labels_[ridx];
    }
    // copy back
    mshadow::Copy(img_, tmpimg);
    labels_.swap(tmplabel);
  }
  /*!
   * \brief read one 4-byte big-endian integer from the stream
   * \param fi stream to read from
   * \return the decoded value; fails with "invalid mnist format" on a short read
   */
  inline static int ReadInt(dmlc::Stream *fi) {
    unsigned char buf[4];
    CHECK(fi->Read(buf, sizeof(buf)) == sizeof(buf))
        << "invalid mnist format";
    // assemble in unsigned arithmetic: shifting a byte with the high bit set
    // into the sign bit of an int would overflow
    const unsigned value = (static_cast<unsigned>(buf[0]) << 24) |
                           (static_cast<unsigned>(buf[1]) << 16) |
                           (static_cast<unsigned>(buf[2]) << 8) |
                           static_cast<unsigned>(buf[3]);
    return static_cast<int>(value);
  }

  /*! \brief MNIST iter params */
  MNISTParam param_;
  /*! \brief output */
  TBlobBatch out_;
  /*! \brief current location */
  index_t loc_;
  /*! \brief image content, buffer allocated in LoadImage and freed in the destructor */
  mshadow::Tensor<cpu, 3> img_;
  /*! \brief label content */
  std::vector<float> labels_;
  /*! \brief batch data tensor, a view into img_ */
  mshadow::Tensor<cpu, 4> batch_data_;
  /*! \brief batch label tensor, a view into labels_ */
  mshadow::Tensor<cpu, 2> batch_label_;
  /*! \brief instance index offset */
  unsigned inst_offset_;
  /*! \brief instance index */
  std::vector<unsigned> inst_;
  // magic number to setup randomness
  static const int kRandMagic = 0;
};  // class MNISTIter
// Registration of the parameter struct and of the iterator itself.
DMLC_REGISTER_PARAMETER(MNISTParam);
// The iterator is registered under the name MNISTIter. Its documented
// arguments are the fields of MNISTParam followed by those of PrefetcherParam,
// because the factory below hands out the MNIST iterator wrapped in a
// PrefetcherIter.
MXNET_REGISTER_IO_ITER(MNISTIter)
.describe(R"code(Iterating on the MNIST dataset.
One can download the dataset from http://yann.lecun.com/exdb/mnist/
)code" ADD_FILELINE)
.add_arguments(MNISTParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
// Factory: the new MNISTIter is passed to the PrefetcherIter constructor;
// NOTE(review): presumably the PrefetcherIter takes ownership of it - confirm
// against iter_prefetcher.h.
.set_body([]() {
return new PrefetcherIter(new MNISTIter());
});
} // namespace io
} // namespace mxnet