/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file iter_prefetcher.h
 * \brief Defines a prefetcher that uses dmlc::ThreadedIter to keep k batches
 *        fetched ahead of the consumer.
 */
#ifndef MXNET_IO_ITER_PREFETCHER_H_
#define MXNET_IO_ITER_PREFETCHER_H_

#include <mxnet/io.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <dmlc/logging.h>
#include <dmlc/threadediter.h>
#include <dmlc/optional.h>
#include <mshadow/tensor.h>
#include <climits>
#include <utility>
#include <string>
#include <vector>
#include <queue>
#include <algorithm>
#include "./inst_vector.h"
#include "./image_iter_common.h"

namespace mxnet {
namespace io {
// Iterator that wraps a TBlobBatch loader (e.g. an image recordio iterator)
// and prefetches DataBatch objects on a background thread.
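//
// Typical usage (an illustrative sketch, not a contract of this header; how
// the underlying TBlobBatch iterator is created is an assumption here):
//
//   IIterator<TBlobBatch>* base = ...;  // some batch loader built elsewhere
//   PrefetcherIter prefetch(base);      // takes ownership of `base`
//   prefetch.Init({{"prefetch_buffer", "4"}});
//   while (prefetch.Next()) {
//     const DataBatch& batch = prefetch.Value();
//     // consume batch.data / batch.index; the batch stays valid until the
//     // next call to prefetch.Next()
//   }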
class PrefetcherIter : public IIterator<DataBatch> {
 public:
  explicit PrefetcherIter(IIterator<TBlobBatch>* base)
      : loader_(base), out_(nullptr), length_hint_(-1) {}

  ~PrefetcherIter() {
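    // free batches that were handed back for recycling but never returned to
    // the threaded iterator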
    while (recycle_queue_.size() != 0) {
      DataBatch* batch = recycle_queue_.front();
      recycle_queue_.pop();
      delete batch;
    }
    delete out_;
    iter.Destroy();
  }

  void InitParams(const std::vector<std::pair<std::string, std::string> >& kwargs) {
    std::vector<std::pair<std::string, std::string> > kwargs_left;
    // init prefetcher param (unrecognized kwargs are returned in kwargs_left)
    kwargs_left = param_.InitAllowUnknown(kwargs);
    CHECK_GT(param_.prefetch_buffer, 0) << "prefetch_buffer must be a positive number";
    // maximum prefetch threaded iter internal size
    const int kMaxPrefetchBuffer = 16;
    // init thread iter
    iter.set_max_capacity(kMaxPrefetchBuffer);
  }
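
  // Illustrative only (the value "4" is an assumption, not a default):
  //   InitParams({{"prefetch_buffer", "4"}});
  // lets up to four already-consumed batches stay alive before their memory is
  // recycled for reuse by the prefetch thread.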

  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
    InitParams(kwargs);
    // use the kwargs to initialize the wrapped batch loader
    loader_->Init(kwargs);
    length_hint_ = loader_->GetLenHint();
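    // iter.Init takes two callbacks: a producer that runs on the prefetch
    // thread and fills *dptr with the next batch (allocating the DataBatch and
    // its NDArrays on first use), and a reset hook that rewinds the wrapped
    // loader whenever BeforeFirst() is called.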
    iter.Init(
        [this](DataBatch** dptr) {
          if (!loader_->Next())
            return false;
          const TBlobBatch& batch = loader_->Value();
          if (*dptr == nullptr) {
            // allocate databatch
            *dptr = new DataBatch();
            (*dptr)->num_batch_padd = batch.num_batch_padd;
            (*dptr)->data.resize(batch.data.size());
            (*dptr)->index.resize(batch.batch_size);
            for (size_t i = 0; i < batch.data.size(); ++i) {
              auto dtype = param_.dtype ? param_.dtype.value() : batch.data[i].type_flag_;
              auto ctx = ((param_.ctx == PrefetcherParam::kCPUPinned) && (param_.device_id >= 0)) ?
                  Context::CPUPinned(param_.device_id) :
                  Context::CPU();
              (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, ctx, false, dtype);
            }
          }
          CHECK_EQ(batch.data.size(), (*dptr)->data.size());
          // copy data over
          for (size_t i = 0; i < batch.data.size(); ++i) {
            if ((*dptr)->data.at(i).shape() != batch.data[i].shape_) {
              // TODO(zhreshold): memory pool for dynamic shaped data
              (*dptr)->data.at(i).ReshapeAndAlloc(batch.data[i].shape_);
            }
            CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_);
            MSHADOW_TYPE_SWITCH(batch.data[i].type_flag_, DType, {
              mshadow::Copy(((*dptr)->data)[i].data().FlatTo2D<cpu, DType>(),
                            batch.data[i].FlatTo2D<cpu, DType>());
            });
            (*dptr)->num_batch_padd = batch.num_batch_padd;
          }
          if (batch.inst_index) {
            std::copy(
                batch.inst_index, batch.inst_index + batch.batch_size, (*dptr)->index.begin());
          }
          return true;
        },
        [this]() {
          loader_->BeforeFirst();
          length_hint_ = loader_->GetLenHint();
        });
  }

  virtual void BeforeFirst(void) {
    iter.BeforeFirst();
  }

  virtual int64_t GetLenHint(void) const {
    return length_hint_;
  }

  virtual bool Next(void) {
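    // hand the batch returned by the previous call back to the recycle queue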
    if (out_ != nullptr) {
      recycle_queue_.push(out_);
      out_ = nullptr;
    }
    // once the recycle queue is full, return the oldest batch to the threaded
    // iterator so its memory can be reused by the producer
    if (recycle_queue_.size() == param_.prefetch_buffer) {
      DataBatch* old_batch = recycle_queue_.front();
      // can be more efficient on engine
      for (NDArray& arr : old_batch->data) {
        arr.WaitToWrite();
      }
      recycle_queue_.pop();
      iter.Recycle(&old_batch);
    }
    return iter.Next(&out_);
  }
  virtual const DataBatch& Value(void) const {
    return *out_;
  }

 protected:
  /*! \brief prefetcher parameters */
  PrefetcherParam param_;
  /*! \brief backend threaded iterator that does the prefetching */
  dmlc::ThreadedIter<DataBatch> iter;
  /*! \brief internal batch loader */
  std::unique_ptr<IIterator<TBlobBatch> > loader_;

 private:
  /*! \brief output data */
  DataBatch* out_;
  /*! \brief queue of batches waiting to be recycled */
  std::queue<DataBatch*> recycle_queue_;
  /*! \brief size hint cache */
  int64_t length_hint_;
};
}  // namespace io
}  // namespace mxnet
#endif  // MXNET_IO_ITER_PREFETCHER_H_