blob: 9050ef2d1b383038c9df706e20318cfc22cdee74 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file iter_prefetcher.h
* \brief defines a prefetcher that uses dmlc::ThreadedIter to keep up to k batches fetched ahead
*/
#ifndef MXNET_IO_ITER_PREFETCHER_H_
#define MXNET_IO_ITER_PREFETCHER_H_
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <dmlc/logging.h>
#include <dmlc/threadediter.h>
#include <dmlc/optional.h>
#include <mshadow/tensor.h>
#include <climits>
#include <utility>
#include <string>
#include <vector>
#include <queue>
#include <algorithm>
#include "./inst_vector.h"
#include "./image_iter_common.h"
namespace mxnet {
namespace io {
// prefetching iterator that wraps a batch loader
/*!
 * \brief Iterator that reads TBlobBatch objects from a wrapped loader and
 *  serves them as DataBatch objects through a dmlc::ThreadedIter.
 *
 * Batches that were handed out by Next() are parked in a recycle queue; once
 * param_.prefetch_buffer of them have accumulated, the oldest one is given
 * back to the threaded iterator so that its storage can be refilled.
 */
class PrefetcherIter : public IIterator<DataBatch> {
 public:
  /*!
   * \brief constructor
   * \param base the batch loader to read from; it is stored in a
   *  std::unique_ptr, so this object takes ownership of it
   */
  explicit PrefetcherIter(IIterator<TBlobBatch>* base)
      : loader_(base), out_(nullptr) {
  }
  /*!
   * \brief delete every batch still held on the consumer side (the recycle
   *  queue and the current output), then destroy the threaded iterator
   */
  ~PrefetcherIter() {
    while (!recycle_queue_.empty()) {
      DataBatch *batch = recycle_queue_.front();
      recycle_queue_.pop();
      delete batch;
    }
    delete out_;
    iter_.Destroy();
  }
  /*!
   * \brief initialize the prefetcher parameters, the wrapped loader and the
   *  threaded iterator
   * \param kwargs key/value settings; param_ is initialized from them with
   *  unknown keys allowed, and the full list is forwarded to the loader
   */
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
    // init prefetcher param; keys it does not know are meant for the loader
    param_.InitAllowUnknown(kwargs);
    // use the kwarg to init batch loader
    loader_->Init(kwargs);
    // maximum prefetch threaded iter internal size
    const int kMaxPrefetchBuffer = 16;
    // init thread iter
    iter_.set_max_capacity(kMaxPrefetchBuffer);
    iter_.Init(
        // producer callback: pull the next batch from the loader and copy it
        // into *dptr; returns false when the loader has no more batches
        [this](DataBatch **dptr) {
          if (!loader_->Next()) return false;
          const TBlobBatch& batch = loader_->Value();
          if (*dptr == nullptr) {
            // no buffer to reuse: allocate a databatch shaped like the source
            *dptr = new DataBatch();
            (*dptr)->data.resize(batch.data.size());
            for (size_t i = 0; i < batch.data.size(); ++i) {
              // requested dtype if param_.dtype is set, else the source dtype
              auto dtype = param_.dtype
                               ? param_.dtype.value()
                               : batch.data[i].type_flag_;
              (*dptr)->data.at(i) = NDArray(batch.data[i].shape_,
                                            Context::CPU(), false,
                                            dtype);
            }
          }
          DataBatch *dst = *dptr;
          CHECK(batch.data.size() == dst->data.size());
          // per-batch metadata: set once per fill, independent of the number
          // of data arrays
          dst->num_batch_padd = batch.num_batch_padd;
          // a reused batch may have been sized for a different batch_size;
          // keep index large enough for the copy below
          dst->index.resize(batch.batch_size);
          // copy data over
          for (size_t i = 0; i < batch.data.size(); ++i) {
            CHECK_EQ(dst->data.at(i).shape(), batch.data[i].shape_);
            // NOTE(review): the copy is typed by the source blob's type_flag_,
            // while the destination may have been allocated with param_.dtype;
            // presumably the loader already yields that dtype -- confirm.
            MSHADOW_TYPE_SWITCH(batch.data[i].type_flag_, DType, {
                mshadow::Copy((dst->data)[i].data().FlatTo2D<cpu, DType>(),
                              batch.data[i].FlatTo2D<cpu, DType>());
            });
          }
          if (batch.inst_index) {
            std::copy(batch.inst_index,
                      batch.inst_index + batch.batch_size,
                      dst->index.begin());
          }
          return true;
        },
        // reset callback: rewind the wrapped loader
        [this]() { loader_->BeforeFirst(); });
  }
  /*! \brief reset the threaded iterator to the beginning */
  virtual void BeforeFirst(void) {
    iter_.BeforeFirst();
  }
  /*!
   * \brief move to the next batch
   *
   * The batch returned by the previous call is moved to the recycle queue.
   * When the queue holds at least param_.prefetch_buffer batches, the oldest
   * one is handed back to the threaded iterator.
   * \return the result of the threaded iterator's Next()
   */
  virtual bool Next(void) {
    if (out_ != nullptr) {
      recycle_queue_.push(out_);
      out_ = nullptr;
    }
    // do recycle
    // The emptiness guard covers prefetch_buffer == 0, where the queue is
    // empty on the first call; '>=' keeps the queue bounded in that case too.
    if (!recycle_queue_.empty() &&
        recycle_queue_.size() >= param_.prefetch_buffer) {
      DataBatch *old_batch = recycle_queue_.front();
      // call WaitToWrite on every array before the batch is handed back
      // can be more efficient on engine
      for (NDArray& arr : old_batch->data) {
        arr.WaitToWrite();
      }
      recycle_queue_.pop();
      iter_.Recycle(&old_batch);
    }
    return iter_.Next(&out_);
  }
  /*!
   * \brief get the current batch
   *
   * out_ is null until Next() has stored a batch in it, so this must only be
   * called after a successful Next().
   * \return reference to the current output batch
   */
  virtual const DataBatch &Value(void) const {
    return *out_;
  }

 protected:
  /*! \brief prefetcher parameters */
  PrefetcherParam param_;
  /*! \brief internal batch loader */
  std::unique_ptr<IIterator<TBlobBatch> > loader_;

 private:
  /*! \brief output data */
  DataBatch *out_;
  /*! \brief queue to be recycled */
  std::queue<DataBatch*> recycle_queue_;
  /*! \brief backend thread */
  dmlc::ThreadedIter<DataBatch> iter_;
};
} // namespace io
} // namespace mxnet
#endif // MXNET_IO_ITER_PREFETCHER_H_