| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file iter_mnist.cc |
| * \brief register mnist iterator |
| */ |
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./iter_prefetcher.h"
#include "../common/utils.h"
| |
| namespace mxnet { |
| namespace io { |
| // Define mnist io parameters |
| struct MNISTParam : public dmlc::Parameter<MNISTParam> { |
| /*! \brief path */ |
| std::string image, label; |
| /*! \brief whether to do shuffle */ |
| bool shuffle; |
| /*! \brief whether to print info */ |
| bool silent; |
| /*! \brief batch size */ |
| int batch_size; |
| /*! \brief data mode */ |
| bool flat; |
| /*! \brief random seed */ |
| int seed; |
| /*! \brief partition the data into multiple parts */ |
| int num_parts; |
| /*! \brief the index of the part will read*/ |
| int part_index; |
| // declare parameters |
| DMLC_DECLARE_PARAMETER(MNISTParam) { |
| DMLC_DECLARE_FIELD(image) |
| .set_default("./train-images-idx3-ubyte") |
| .describe("Dataset Param: Mnist image path."); |
| DMLC_DECLARE_FIELD(label) |
| .set_default("./train-labels-idx1-ubyte") |
| .describe("Dataset Param: Mnist label path."); |
| DMLC_DECLARE_FIELD(batch_size) |
| .set_lower_bound(1) |
| .set_default(128) |
| .describe("Batch Param: Batch Size."); |
| DMLC_DECLARE_FIELD(shuffle).set_default(true).describe( |
| "Augmentation Param: Whether to shuffle data."); |
| DMLC_DECLARE_FIELD(flat).set_default(false).describe( |
| "Augmentation Param: Whether to flat the data into 1D."); |
| DMLC_DECLARE_FIELD(seed).set_default(0).describe("Augmentation Param: Random Seed."); |
| DMLC_DECLARE_FIELD(silent).set_default(false).describe( |
| "Auxiliary Param: Whether to print out data info."); |
| DMLC_DECLARE_FIELD(num_parts).set_default(1).describe("partition the data into multiple parts"); |
| DMLC_DECLARE_FIELD(part_index).set_default(0).describe("the index of the part will read"); |
| } |
| }; |
| |
| class MNISTIter : public IIterator<TBlobBatch> { |
| public: |
| MNISTIter() { |
| img_.dptr_ = nullptr; |
| out_.data.resize(2); |
| } |
| ~MNISTIter() override { |
| delete[] img_.dptr_; |
| } |
| // intialize iterator loads data in |
| void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { |
| std::map<std::string, std::string> kmap(kwargs.begin(), kwargs.end()); |
| param_.InitAllowUnknown(kmap); |
| this->LoadImage(); |
| this->LoadLabel(); |
| if (param_.flat) { |
| batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, 1, img_.size(1) * img_.size(2)); |
| } else { |
| batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, img_.size(1), img_.size(2)); |
| } |
| out_.data.clear(); |
| batch_label_.shape_ = mshadow::Shape2(param_.batch_size, 1); |
| batch_label_.stride_ = 1; |
| batch_data_.stride_ = batch_data_.size(3); |
| out_.batch_size = param_.batch_size; |
| if (param_.shuffle) |
| this->Shuffle(); |
| if (param_.silent == 0) { |
| mxnet::TShape s; |
| s = batch_data_.shape_; |
| if (param_.flat) { |
| LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) |
| << " images, shuffle=" << param_.shuffle << ", shape=" << s.FlatTo2D(); |
| } else { |
| LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) |
| << " images, shuffle=" << param_.shuffle << ", shape=" << s; |
| } |
| } |
| } |
| void BeforeFirst() override { |
| this->loc_ = 0; |
| } |
| bool Next() override { |
| if (loc_ + param_.batch_size <= img_.size(0)) { |
| batch_data_.dptr_ = img_[loc_].dptr_; |
| batch_label_.dptr_ = &labels_[loc_]; |
| out_.data.clear(); |
| if (param_.flat) { |
| out_.data.emplace_back(batch_data_.FlatTo2D()); |
| } else { |
| out_.data.emplace_back(batch_data_); |
| } |
| out_.data.emplace_back(batch_label_); |
| loc_ += param_.batch_size; |
| return true; |
| } else { |
| return false; |
| } |
| } |
| const TBlobBatch& Value() const override { |
| return out_; |
| } |
| |
| private: |
| inline void GetPart(int count, int* start, int* end) { |
| CHECK_GE(param_.part_index, 0); |
| CHECK_GT(param_.num_parts, 0); |
| CHECK_GT(param_.num_parts, param_.part_index); |
| |
| *start = static_cast<int>(static_cast<double>(count) / param_.num_parts * param_.part_index); |
| *end = |
| static_cast<int>(static_cast<double>(count) / param_.num_parts * (param_.part_index + 1)); |
| } |
| |
| inline void LoadImage() { |
| dmlc::SeekStream* stdimg = dmlc::SeekStream::CreateForRead(param_.image.c_str()); |
| ReadInt(stdimg); |
| int image_count = ReadInt(stdimg); |
| int image_rows = ReadInt(stdimg); |
| int image_cols = ReadInt(stdimg); |
| |
| int start, end; |
| GetPart(image_count, &start, &end); |
| image_count = end - start; |
| if (start > 0) { |
| stdimg->Seek(stdimg->Tell() + start * image_rows * image_cols); |
| } |
| |
| img_.shape_ = mshadow::Shape3(image_count, image_rows, image_cols); |
| img_.stride_ = img_.size(2); |
| |
| // allocate continuous memory |
| img_.dptr_ = new float[img_.MSize()]; |
| for (int i = 0; i < image_count; ++i) { |
| for (int j = 0; j < image_rows; ++j) { |
| for (int k = 0; k < image_cols; ++k) { |
| unsigned char ch; |
| CHECK(stdimg->Read(&ch, sizeof(ch) != 0)); |
| img_[i][j][k] = ch; |
| } |
| } |
| } |
| // normalize to 0-1 |
| img_ *= 1.0f / 256.0f; |
| delete stdimg; |
| } |
| inline void LoadLabel() { |
| dmlc::SeekStream* stdlabel = dmlc::SeekStream::CreateForRead(param_.label.c_str()); |
| ReadInt(stdlabel); |
| int labels_count = ReadInt(stdlabel); |
| |
| int start, end; |
| GetPart(labels_count, &start, &end); |
| labels_count = end - start; |
| if (start > 0) { |
| stdlabel->Seek(stdlabel->Tell() + start); |
| } |
| |
| labels_.resize(labels_count); |
| for (int i = 0; i < labels_count; ++i) { |
| unsigned char ch; |
| CHECK(stdlabel->Read(&ch, sizeof(ch) != 0)); |
| labels_[i] = ch; |
| inst_.push_back((unsigned)i + inst_offset_); |
| } |
| delete stdlabel; |
| } |
| inline void Shuffle() { |
| std::shuffle(inst_.begin(), inst_.end(), common::RANDOM_ENGINE(kRandMagic + param_.seed)); |
| std::vector<float> tmplabel(labels_.size()); |
| mshadow::TensorContainer<cpu, 3> tmpimg(img_.shape_); |
| for (size_t i = 0; i < inst_.size(); ++i) { |
| unsigned ridx = inst_[i] - inst_offset_; |
| mshadow::Copy(tmpimg[i], img_[ridx]); |
| tmplabel[i] = labels_[ridx]; |
| } |
| // copy back |
| mshadow::Copy(img_, tmpimg); |
| labels_ = tmplabel; |
| } |
| |
| private: |
| inline static int ReadInt(dmlc::Stream* fi) { |
| unsigned char buf[4]; |
| CHECK(fi->Read(buf, sizeof(buf)) == sizeof(buf)) << "invalid mnist format"; |
| #ifdef _MSC_VER |
| return (buf[0] << 24 | buf[1] << 16 | buf[2] << 8 | buf[3]); |
| #else |
| return reinterpret_cast<int>(buf[0] << 24 | buf[1] << 16 | buf[2] << 8 | buf[3]); |
| #endif |
| } |
| |
| private: |
| /*! \brief MNIST iter params */ |
| MNISTParam param_; |
| /*! \brief output */ |
| TBlobBatch out_; |
| /*! \brief current location */ |
| index_t loc_{0}; |
| /*! \brief image content */ |
| mshadow::Tensor<cpu, 3> img_; |
| /*! \brief label content */ |
| std::vector<float> labels_; |
| /*! \brief batch data tensor */ |
| mshadow::Tensor<cpu, 4> batch_data_; |
| /*! \brief batch label tensor */ |
| mshadow::Tensor<cpu, 2> batch_label_; |
| /*! \brief instance index offset */ |
| unsigned inst_offset_{0}; |
| /*! \brief instance index */ |
| std::vector<unsigned> inst_; |
| // magic number to setup randomness |
| static const int kRandMagic = 0; |
| }; // class MNISTIter |
| |
| DMLC_REGISTER_PARAMETER(MNISTParam); |
| |
| MXNET_REGISTER_IO_ITER(MNISTIter) |
| .describe("Iterating on the MNIST dataset." ADD_FILELINE) |
| .add_arguments(MNISTParam::__FIELDS__()) |
| .add_arguments(PrefetcherParam::__FIELDS__()) |
| .set_body([]() { return new PrefetcherIter(new MNISTIter()); }); |
| |
| } // namespace io |
| } // namespace mxnet |