/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file iter_prefetcher.h
* \brief define a prefetcher using threaditer to keep k batch fetched
*/
#ifndef MXNET_IO_ITER_PREFETCHER_H_
#define MXNET_IO_ITER_PREFETCHER_H_
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <dmlc/logging.h>
#include <dmlc/threadediter.h>
#include <dmlc/optional.h>
#include <mshadow/tensor.h>
#include <climits>
#include <utility>
#include <string>
#include <vector>
#include <queue>
#include <algorithm>
#include "./inst_vector.h"
#include "./image_iter_common.h"
namespace mxnet {
namespace io {
// iterator on image recordio
/*!
 * \brief Iterator that prefetches DataBatch objects from a TBlobBatch loader.
 *
 * A dmlc::ThreadedIter runs the wrapped loader on a background thread so that
 * batch production (decode/copy) overlaps with the consumer's computation.
 * Consumed batches are pushed onto a recycle queue and handed back to the
 * producer thread once their NDArrays are safe to overwrite, so the NDArray
 * buffers are reused instead of reallocated every batch.
 */
class PrefetcherIter : public IIterator<DataBatch> {
public:
// Takes ownership of `base` (held in loader_, a unique_ptr).
// length_hint_ starts at -1 (unknown) until Init() queries the loader.
explicit PrefetcherIter(IIterator<TBlobBatch>* base)
: loader_(base), out_(nullptr), length_hint_(-1) {}
~PrefetcherIter() {
// Free every batch still waiting to be recycled; these were allocated
// by the producer callback in Init().
while (recycle_queue_.size() != 0) {
DataBatch* batch = recycle_queue_.front();
recycle_queue_.pop();
delete batch;
}
// out_ is the batch currently exposed through Value(); may be nullptr.
delete out_;
// Stop the producer thread explicitly before member destruction begins,
// so the callbacks can no longer touch loader_ / length_hint_.
iter.Destroy();
}
/*!
 * \brief Parse prefetcher parameters from kwargs and cap the threaded
 *        iterator's internal queue.
 *
 * Unknown keys are tolerated (InitAllowUnknown) because the same kwargs
 * are also forwarded to the wrapped loader in Init().
 * \param kwargs list of (key, value) string pairs
 */
void InitParams(const std::vector<std::pair<std::string, std::string> >& kwargs) {
std::vector<std::pair<std::string, std::string> > kwargs_left;
// init image rec param
// NOTE(review): kwargs_left is assigned but never used afterwards.
kwargs_left = param_.InitAllowUnknown(kwargs);
CHECK_GT(param_.prefetch_buffer, 0) << "Prefetch_buffer must be positive number";
// maximum prefetch threaded iter internal size
const int kMaxPrefetchBuffer = 16;
// init thread iter
// Hard cap on in-flight batches regardless of param_.prefetch_buffer.
iter.set_max_capacity(kMaxPrefetchBuffer);
}
/*!
 * \brief Initialize the wrapped loader and start the background producer.
 *
 * The first lambda is the producer: it advances the loader and copies the
 * resulting TBlobBatch into a (possibly recycled) DataBatch. The second
 * lambda is the reset hook run on BeforeFirst().
 * \param kwargs forwarded to both InitParams() and the wrapped loader
 */
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
InitParams(kwargs);
// use the kwarg to init batch loader
loader_->Init(kwargs);
length_hint_ = loader_->GetLenHint();
iter.Init(
[this](DataBatch** dptr) {
// Returning false signals end of epoch to ThreadedIter.
if (!loader_->Next())
return false;
const TBlobBatch& batch = loader_->Value();
if (*dptr == nullptr) {
// allocate databatch
// First use of this slot: build a DataBatch whose NDArrays mirror
// the loader's blob shapes/dtypes. Recycled slots skip this.
*dptr = new DataBatch();
(*dptr)->num_batch_padd = batch.num_batch_padd;
(*dptr)->data.resize(batch.data.size());
(*dptr)->index.resize(batch.batch_size);
for (size_t i = 0; i < batch.data.size(); ++i) {
// param_.dtype, when set, overrides the loader's blob dtype.
auto dtype = param_.dtype ? param_.dtype.value() : batch.data[i].type_flag_;
// Pinned host memory only when explicitly requested AND a device
// id is given; otherwise plain CPU memory.
auto ctx = ((param_.ctx == PrefetcherParam::kCPUPinned) && (param_.device_id >= 0)) ?
Context::CPUPinned(param_.device_id) :
Context::CPU();
// NOTE(review): third ctor arg presumably delay_alloc=false,
// i.e. allocate storage eagerly — confirm against NDArray ctor.
(*dptr)->data.at(i) = NDArray(batch.data[i].shape_, ctx, false, dtype);
}
}
// Loader must keep a stable number of blobs across batches.
CHECK(batch.data.size() == (*dptr)->data.size());
// copy data over
for (size_t i = 0; i < batch.data.size(); ++i) {
if ((*dptr)->data.at(i).shape() != batch.data[i].shape_) {
// TODO(zhreshold): memory pool for dynamic shaped data
// Shape changed since this DataBatch was allocated: reallocate.
(*dptr)->data.at(i).ReshapeAndAlloc(batch.data[i].shape_);
}
CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_);
// Synchronous element-wise copy on the producer thread, dispatched
// on the source blob's dtype.
MSHADOW_TYPE_SWITCH(batch.data[i].type_flag_, DType, {
mshadow::Copy(((*dptr)->data)[i].data().FlatTo2D<cpu, DType>(),
batch.data[i].FlatTo2D<cpu, DType>());
});
(*dptr)->num_batch_padd = batch.num_batch_padd;
}
// inst_index may be null when the loader does not track instance ids.
if (batch.inst_index) {
std::copy(
batch.inst_index, batch.inst_index + batch.batch_size, (*dptr)->index.begin());
}
return true;
},
[this]() {
// Reset hook: rewind the loader and refresh the cached length hint.
loader_->BeforeFirst();
length_hint_ = loader_->GetLenHint();
});
}
virtual void BeforeFirst(void) {
iter.BeforeFirst();
}
// Cached from the loader at Init()/BeforeFirst(); -1 means unknown.
virtual int64_t GetLenHint(void) const {
return length_hint_;
}
/*!
 * \brief Advance to the next batch.
 * \return true if a new batch is available via Value()
 */
virtual bool Next(void) {
// The previously returned batch is now consumed; queue it for reuse.
if (out_ != nullptr) {
recycle_queue_.push(out_);
out_ = nullptr;
}
// do recycle
// Once prefetch_buffer batches are queued, block until the oldest one's
// NDArrays are writable again, then hand it back to the producer thread.
if (recycle_queue_.size() == param_.prefetch_buffer) {
DataBatch* old_batch = recycle_queue_.front();
// can be more efficient on engine
for (NDArray& arr : old_batch->data) {
arr.WaitToWrite();
}
recycle_queue_.pop();
iter.Recycle(&old_batch);
}
return iter.Next(&out_);
}
// Valid only between a successful Next() and the following Next()/dtor.
virtual const DataBatch& Value(void) const {
return *out_;
}
protected:
/*! \brief prefetcher parameters */
PrefetcherParam param_;
/*! \brief backend thread */
dmlc::ThreadedIter<DataBatch> iter;
/*! \brief internal batch loader */
std::unique_ptr<IIterator<TBlobBatch> > loader_;
private:
/*! \brief output data */
DataBatch* out_;
/*! \brief queue to be recycled */
std::queue<DataBatch*> recycle_queue_;
/*! \brief size hint cache */
int64_t length_hint_;
};
} // namespace io
} // namespace mxnet
#endif // MXNET_IO_ITER_PREFETCHER_H_