blob: fa5e0072b7b15544824d88ac323925cdf0f611ff [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file iter_mnist.cc
* \brief register mnist iterator
*/
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./iter_prefetcher.h"
#include "../common/utils.h"
namespace mxnet {
namespace io {
// Parameters of the MNIST iterator, parsed from the iterator's key/value
// arguments. num_parts / part_index declare no bounds here; they are checked
// (part_index >= 0, num_parts > 0, part_index < num_parts) when data is loaded.
struct MNISTParam : public dmlc::Parameter<MNISTParam> {
/*! \brief paths of the image file and the label file to load */
std::string image, label;
/*! \brief whether to shuffle the loaded instances (done once, after loading) */
bool shuffle;
/*! \brief when true, suppress the info log printed after loading */
bool silent;
/*! \brief number of instances per batch */
int batch_size;
/*! \brief whether each image is exposed flattened (rows * cols) instead of as a 2D plane */
bool flat;
/*! \brief seed for the shuffling random engine */
int seed;
/*! \brief number of contiguous parts the dataset is split into */
int num_parts;
/*! \brief 0-based index of the part this iterator reads */
int part_index;
// declare parameters: defaults, bounds and help text for each field
DMLC_DECLARE_PARAMETER(MNISTParam) {
DMLC_DECLARE_FIELD(image)
.set_default("./train-images-idx3-ubyte")
.describe("Dataset Param: Mnist image path.");
DMLC_DECLARE_FIELD(label)
.set_default("./train-labels-idx1-ubyte")
.describe("Dataset Param: Mnist label path.");
DMLC_DECLARE_FIELD(batch_size)
.set_lower_bound(1)
.set_default(128)
.describe("Batch Param: Batch Size.");
DMLC_DECLARE_FIELD(shuffle).set_default(true).describe(
"Augmentation Param: Whether to shuffle data.");
DMLC_DECLARE_FIELD(flat).set_default(false).describe(
"Augmentation Param: Whether to flat the data into 1D.");
DMLC_DECLARE_FIELD(seed).set_default(0).describe("Augmentation Param: Random Seed.");
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
"Auxiliary Param: Whether to print out data info.");
// default: a single part covering the whole dataset
DMLC_DECLARE_FIELD(num_parts).set_default(1).describe("partition the data into multiple parts");
DMLC_DECLARE_FIELD(part_index).set_default(0).describe("the index of the part will read");
}
};
/*!
 * \brief In-memory iterator over the MNIST dataset stored in idx files.
 *
 * Init() reads the (optionally partitioned) image and label files completely
 * into memory. Next() then exposes consecutive batches as views into that
 * buffer. Only full batches are produced; a trailing partial batch is skipped.
 */
class MNISTIter : public IIterator<TBlobBatch> {
 public:
  MNISTIter() {
    img_.dptr_ = nullptr;
    out_.data.resize(2);
  }
  // img_.dptr_ is an owning raw pointer released in the destructor, so a
  // compiler-generated copy would lead to a double free.
  MNISTIter(const MNISTIter&) = delete;
  MNISTIter& operator=(const MNISTIter&) = delete;
  ~MNISTIter() override {
    delete[] img_.dptr_;
  }
  /*!
   * \brief parse the iterator arguments and load the dataset into memory.
   * \param kwargs key/value arguments; keys not declared in MNISTParam are
   *        tolerated (parsed with InitAllowUnknown).
   */
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    std::map<std::string, std::string> kmap(kwargs.begin(), kwargs.end());
    param_.InitAllowUnknown(kmap);
    this->LoadImage();
    this->LoadLabel();
    // Next() and Shuffle() address images and labels with the same index, so
    // both files must hold the same number of instances (after partitioning).
    CHECK_EQ(labels_.size(), static_cast<size_t>(img_.size(0)))
        << "MNISTIter: " << param_.image << " and " << param_.label
        << " contain a different number of instances";
    if (param_.flat) {
      batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, 1, img_.size(1) * img_.size(2));
    } else {
      batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, img_.size(1), img_.size(2));
    }
    out_.data.clear();
    batch_label_.shape_ = mshadow::Shape2(param_.batch_size, 1);
    batch_label_.stride_ = 1;
    batch_data_.stride_ = batch_data_.size(3);
    out_.batch_size = param_.batch_size;
    if (param_.shuffle) {
      this->Shuffle();
    }
    if (!param_.silent) {
      mxnet::TShape s;
      s = batch_data_.shape_;
      if (param_.flat) {
        LOG(INFO) << "MNISTIter: load " << static_cast<unsigned>(img_.size(0))
                  << " images, shuffle=" << param_.shuffle << ", shape=" << s.FlatTo2D();
      } else {
        LOG(INFO) << "MNISTIter: load " << static_cast<unsigned>(img_.size(0))
                  << " images, shuffle=" << param_.shuffle << ", shape=" << s;
      }
    }
  }
  /*! \brief rewind to the first batch */
  void BeforeFirst() override {
    this->loc_ = 0;
  }
  /*!
   * \brief advance to the next full batch.
   * \return false when fewer than batch_size instances remain.
   */
  bool Next() override {
    if (loc_ + param_.batch_size <= img_.size(0)) {
      // The batch tensors are views into img_ / labels_; no data is copied.
      batch_data_.dptr_ = img_[loc_].dptr_;
      batch_label_.dptr_ = &labels_[loc_];
      out_.data.clear();
      if (param_.flat) {
        out_.data.emplace_back(batch_data_.FlatTo2D());
      } else {
        out_.data.emplace_back(batch_data_);
      }
      out_.data.emplace_back(batch_label_);
      loc_ += param_.batch_size;
      return true;
    } else {
      return false;
    }
  }
  /*! \brief the batch produced by the last successful Next() */
  const TBlobBatch& Value() const override {
    return out_;
  }

 private:
  /*!
   * \brief compute the [start, end) instance range of part part_index when
   *        \p count instances are split into num_parts contiguous parts.
   */
  inline void GetPart(int count, int* start, int* end) {
    CHECK_GE(param_.part_index, 0);
    CHECK_GT(param_.num_parts, 0);
    CHECK_GT(param_.num_parts, param_.part_index);
    *start = static_cast<int>(static_cast<double>(count) / param_.num_parts * param_.part_index);
    *end =
        static_cast<int>(static_cast<double>(count) / param_.num_parts * (param_.part_index + 1));
  }
  /*! \brief load this part's images into img_, scaled by 1/256 */
  inline void LoadImage() {
    // unique_ptr closes the stream even if a CHECK below throws.
    std::unique_ptr<dmlc::SeekStream> stdimg(
        dmlc::SeekStream::CreateForRead(param_.image.c_str()));
    ReadInt(stdimg.get());  // magic number (not validated)
    int image_count = ReadInt(stdimg.get());
    const int image_rows = ReadInt(stdimg.get());
    const int image_cols = ReadInt(stdimg.get());
    CHECK_GE(image_count, 0) << "invalid mnist format";
    CHECK_GT(image_rows, 0) << "invalid mnist format";
    CHECK_GT(image_cols, 0) << "invalid mnist format";
    int start = 0, end = 0;
    GetPart(image_count, &start, &end);
    image_count = end - start;
    const size_t image_size = static_cast<size_t>(image_rows) * static_cast<size_t>(image_cols);
    if (start > 0) {
      // one byte per pixel: skip the images belonging to earlier parts
      stdimg->Seek(stdimg->Tell() + static_cast<size_t>(start) * image_size);
    }
    img_.shape_ = mshadow::Shape3(image_count, image_rows, image_cols);
    img_.stride_ = img_.size(2);
    // Release a buffer left over from an earlier Init() before allocating.
    delete[] img_.dptr_;
    img_.dptr_ = nullptr;
    // allocate continuous memory
    img_.dptr_ = new float[img_.MSize()];
    // Read one image per call instead of one byte per virtual Read().
    std::vector<unsigned char> pixels(image_size);
    for (int i = 0; i < image_count; ++i) {
      ReadExact(stdimg.get(), pixels.data(), pixels.size());
      const unsigned char* src = pixels.data();
      for (int j = 0; j < image_rows; ++j) {
        for (int k = 0; k < image_cols; ++k) {
          img_[i][j][k] = *src++;
        }
      }
    }
    // Scale bytes by 1/256 (so 255 maps to 255/256, not exactly 1).
    img_ *= 1.0f / 256.0f;
  }
  /*! \brief load this part's labels into labels_ and reset inst_ */
  inline void LoadLabel() {
    std::unique_ptr<dmlc::SeekStream> stdlabel(
        dmlc::SeekStream::CreateForRead(param_.label.c_str()));
    ReadInt(stdlabel.get());  // magic number (not validated)
    int labels_count = ReadInt(stdlabel.get());
    CHECK_GE(labels_count, 0) << "invalid mnist format";
    int start = 0, end = 0;
    GetPart(labels_count, &start, &end);
    labels_count = end - start;
    if (start > 0) {
      // one byte per label: skip the labels belonging to earlier parts
      stdlabel->Seek(stdlabel->Tell() + static_cast<size_t>(start));
    }
    std::vector<unsigned char> raw(static_cast<size_t>(labels_count));
    ReadExact(stdlabel.get(), raw.data(), raw.size());
    labels_.assign(raw.begin(), raw.end());
    // Rebuild (not append to) the instance index so repeated Init() is safe.
    inst_.clear();
    inst_.reserve(raw.size());
    for (int i = 0; i < labels_count; ++i) {
      inst_.push_back(static_cast<unsigned>(i) + inst_offset_);
    }
  }
  /*! \brief permute images and labels together with a seeded engine */
  inline void Shuffle() {
    std::shuffle(inst_.begin(), inst_.end(), common::RANDOM_ENGINE(kRandMagic + param_.seed));
    std::vector<float> tmplabel(labels_.size());
    mshadow::TensorContainer<cpu, 3> tmpimg(img_.shape_);
    for (size_t i = 0; i < inst_.size(); ++i) {
      const unsigned ridx = inst_[i] - inst_offset_;
      mshadow::Copy(tmpimg[i], img_[ridx]);
      tmplabel[i] = labels_[ridx];
    }
    // copy back
    mshadow::Copy(img_, tmpimg);
    labels_.swap(tmplabel);
  }

 private:
  /*!
   * \brief read exactly \p size bytes into \p dst.
   *
   * Loops because a stream may return fewer bytes than requested. Fails if
   * the stream ends early.
   */
  inline static void ReadExact(dmlc::Stream* fi, void* dst, size_t size) {
    char* out = static_cast<char*>(dst);
    while (size != 0) {
      const size_t nread = fi->Read(out, size);
      CHECK(nread != 0) << "invalid mnist format: unexpected end of file";
      out += nread;
      size -= nread;
    }
  }
  /*! \brief read one big-endian 32-bit integer from an idx file header */
  inline static int ReadInt(dmlc::Stream* fi) {
    unsigned char buf[4];
    CHECK(fi->Read(buf, sizeof(buf)) == sizeof(buf)) << "invalid mnist format";
    // Assemble in unsigned arithmetic: shifting a byte >= 0x80 into the sign
    // bit of a signed int is undefined behavior.
    const uint32_t value = (static_cast<uint32_t>(buf[0]) << 24) |
                           (static_cast<uint32_t>(buf[1]) << 16) |
                           (static_cast<uint32_t>(buf[2]) << 8) |
                           static_cast<uint32_t>(buf[3]);
    return static_cast<int>(value);
  }

 private:
  /*! \brief MNIST iter params */
  MNISTParam param_;
  /*! \brief output */
  TBlobBatch out_;
  /*! \brief index of the first instance of the next batch */
  index_t loc_{0};
  /*! \brief image content (owned; allocated in LoadImage, freed in the destructor) */
  mshadow::Tensor<cpu, 3> img_;
  /*! \brief label content */
  std::vector<float> labels_;
  /*! \brief batch data tensor (view into img_) */
  mshadow::Tensor<cpu, 4> batch_data_;
  /*! \brief batch label tensor (view into labels_) */
  mshadow::Tensor<cpu, 2> batch_label_;
  /*! \brief instance index offset */
  unsigned inst_offset_{0};
  /*! \brief instance index */
  std::vector<unsigned> inst_;
  // magic number to setup randomness
  static const int kRandMagic = 0;
};  // class MNISTIter
// Register MNISTParam with dmlc's parameter machinery.
DMLC_REGISTER_PARAMETER(MNISTParam);
// Register the MNIST iterator (presumably exposed to frontends under the name
// "MNISTIter" -- confirm against MXNET_REGISTER_IO_ITER). The factory wraps a
// fresh MNISTIter in a PrefetcherIter (iter_prefetcher.h). The accepted
// arguments are the fields of both MNISTParam and PrefetcherParam.
MXNET_REGISTER_IO_ITER(MNISTIter)
.describe("Iterating on the MNIST dataset." ADD_FILELINE)
.add_arguments(MNISTParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.set_body([]() { return new PrefetcherIter(new MNISTIter()); });
} // namespace io
} // namespace mxnet