blob: 47754470453c7782d11e5f73db608243be6c38a7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2020 by Contributors
* \file dataloader.cc
* \brief Pure C++ backed dataloader implementation
*/
#include <dmlc/parameter.h>
#include <dmlc/omp.h>
#include <mxnet/io.h>
#include "./inst_vector.h"
#include "./iter_prefetcher.h"
#include "../profiler/custom_op_profiler.h"
namespace mxnet {
namespace io {
/*!
 * \brief Parameters accepted by ThreadedDataLoader.
 *
 * The dataset, sampler and batchify function are not passed as objects but as
 * integer-encoded addresses; ThreadedDataLoader::Init casts each of them back
 * to the pointer type noted on the field. The caller therefore has to keep the
 * pointed-to objects alive at least until Init has run.
 */
struct ThreadedDataLoaderParam : public dmlc::Parameter<ThreadedDataLoaderParam> {
/*!
 * \brief Requested number of worker threads (default 0).
 * ThreadedDataLoader::Init overwrites this with the size of the OpenMP team
 * it actually obtains, which is at least 1.
 */
int num_workers;
/*! \brief Address of a std::shared_ptr<Dataset>; Init copies the shared_ptr it points to. */
std::intptr_t dataset;
/*! \brief Address of an IIterator<DataBatch> used as the sampler; Init stores it as a raw pointer. */
std::intptr_t sampler;
/*! \brief Address of a BatchifyFunctionPtr; Init copies the value it points to. */
std::intptr_t batchify_fn;
/*!
 * \brief Device id for pinned memory (default -1).
 * NOTE(review): this field is not read by ThreadedDataLoader in this file;
 * presumably it is consumed elsewhere - confirm.
 */
int pin_device_id;
// Field declarations: num_workers and pin_device_id declare defaults,
// the three pointer fields declare none.
DMLC_DECLARE_PARAMETER(ThreadedDataLoaderParam) {
DMLC_DECLARE_FIELD(num_workers).set_default(0)
.describe("Number of thread workers.");
DMLC_DECLARE_FIELD(dataset)
.describe("Pointer to shared Dataset.");
DMLC_DECLARE_FIELD(sampler)
.describe("Pointer to Sampler.");
DMLC_DECLARE_FIELD(batchify_fn)
.describe("Pointer to Batchify function.");
DMLC_DECLARE_FIELD(pin_device_id).set_default(-1)
.describe("If not negative, will move data to pinned memory.");
}
}; // struct ThreadedDataLoaderParam
// Registers the parameter struct with the DMLC parameter machinery.
DMLC_REGISTER_PARAMETER(ThreadedDataLoaderParam);
/*!
 * \brief Data loader iterator that fetches the samples of a batch with OpenMP
 *        worker threads and merges them through a batchify function.
 *
 * Each call to Next() pulls one batch of sample indices from the sampler,
 * loads the matching items from the dataset in parallel, pads the batch up to
 * the sampler's batch size and passes the items to the batchify function.
 * The dataset, sampler and batchify function are supplied as integer-encoded
 * pointers through ThreadedDataLoaderParam.
 *
 * \tparam DType not referenced by the implementation.
 */
template<typename DType = real_t>
class ThreadedDataLoader : public IIterator<TBlobBatch> {
 public:
  ThreadedDataLoader() = default;
  ~ThreadedDataLoader() override = default;

  /*!
   * \brief Parse the parameters, settle the worker count and bind the
   *        dataset, sampler and batchify function.
   * \param kwargs key/value parameters; keys unknown to
   *        ThreadedDataLoaderParam are tolerated (InitAllowUnknown).
   *
   * Fails a CHECK when one of the pointer parameters is zero.
   */
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.InitAllowUnknown(kwargs);

    // omp_get_num_procs() needs no parallel region; calling it from the
    // encountering thread avoids every team member racing on one variable.
    const int maxthread = std::max(omp_get_num_procs(), 1);
    // be conservative, never request more threads than there are processors
    const int requested = std::min(maxthread, param_.num_workers);
    // keep the team size that the runtime really grants, at least one thread
    param_.num_workers = std::max(1, ProbeTeamSize(requested));

    CHECK(param_.dataset != 0) << "ThreadedDataLoader requires a non-null dataset pointer";
    CHECK(param_.sampler != 0) << "ThreadedDataLoader requires a non-null sampler pointer";
    CHECK(param_.batchify_fn != 0)
      << "ThreadedDataLoader requires a non-null batchify_fn pointer";

    dataset_ = *static_cast<std::shared_ptr<Dataset>*>(
        reinterpret_cast<void*>(param_.dataset));
    CHECK(dataset_ != nullptr) << "ThreadedDataLoader received an empty dataset";
    dataset_len_ = dataset_->GetLen();
    sampler_ = static_cast<IIterator<DataBatch>*>(
        reinterpret_cast<void*>(param_.sampler));
    batchify_fn_ = *static_cast<BatchifyFunctionPtr*>(
        reinterpret_cast<void*>(param_.batchify_fn));
    CHECK(batchify_fn_ != nullptr) << "ThreadedDataLoader received an empty batchify function";

    this->BeforeFirst();
  }

  /*! \brief Rewind by rewinding the sampler. */
  void BeforeFirst() override {
    sampler_->BeforeFirst();
  }

  /*! \brief Forward the length hint reported by the sampler. */
  int64_t GetLenHint() const override {
    return sampler_->GetLenHint();
  }

  /*!
   * \brief Produce the next batch.
   * \return false when the sampler is exhausted, true once Value() holds a
   *         new batch.
   *
   * Fails a CHECK when the sampler output is malformed, when the dataset
   * cannot deliver an item or when the batchify function reports failure.
   */
  bool Next() override {
    if (!sampler_->Next()) return false;

    // Bind by reference: the sampler is not advanced again inside this call,
    // so there is no need to copy the whole batch.
    const auto& samples = sampler_->Value();
    CHECK(!samples.data.empty()) << "Sampler returned a batch without index data";

    const auto batch_size = samples.data[0].shape().Size();
    const int64_t num_padding = static_cast<int64_t>(samples.num_batch_padd);
    CHECK(num_padding >= 0 && num_padding <= static_cast<int64_t>(batch_size))
      << "Invalid batch padding " << num_padding << " for batch size " << batch_size;
    // signed int keeps the OpenMP loop variable below a plain int
    const int real_batch_size =
        static_cast<int>(static_cast<int64_t>(batch_size) - num_padding);

    // NOTE(review): the index buffer is read as int64_t without checking the
    // array's dtype - confirm that every sampler emits 64-bit indices.
    const int64_t* idx_ptr = static_cast<int64_t*>(samples.data[0].data().dptr_);
    const std::vector<int64_t> idx_ptrs(idx_ptr, idx_ptr + real_batch_size);

    // __getitem__
    std::vector<std::vector<NDArray> > inputs(batch_size);
    const bool profiling =
        profiler::Profiler::Get()->IsProfiling(profiler::Profiler::kImperative);
    if (profiling) {
      profiler::CustomOpProfiler::Get()->OnCustomBegin("MXThreadedDataLoaderGetItems");
    }
    #pragma omp parallel for num_threads(param_.num_workers)
    for (int i = 0; i < real_batch_size; ++i) {
      // exceptions are captured inside the parallel loop and rethrown after it
      omp_exc_.Run([&] {
        auto idx = idx_ptrs[i];
        CHECK(dataset_->GetItem(idx, &inputs[i]))
          << "Error getting data # " << idx;
      });
    }
    if (profiling) {
      profiler::CustomOpProfiler::Get()->OnCustomEnd();
    }
    omp_exc_.Rethrow();

    // pad to normal batch size by repeating the first item
    for (size_t i = static_cast<size_t>(real_batch_size); i < batch_size; ++i) {
      inputs[i] = inputs[0];
    }

    // batchify
    if (profiling) {
      profiler::CustomOpProfiler::Get()->OnCustomBegin("MXThreadedDataLoaderBatchify");
    }
    CHECK(batchify_fn_->Batchify(inputs, &batched_buffer_))
      << "Error call batchify inside dataloader";
    if (profiling) {
      profiler::CustomOpProfiler::Get()->OnCustomEnd();
    }

    // out_ only references the arrays owned by batched_buffer_;
    // batch_size is set to the number of arrays produced by batchify.
    out_.batch_size = batched_buffer_.size();
    out_.data.resize(batched_buffer_.size());
    for (size_t i = 0; i < batched_buffer_.size(); ++i) {
      out_.data[i] = batched_buffer_[i].data();
    }
    out_.num_batch_padd = samples.num_batch_padd;
    return true;
  }

  /*! \brief Batch produced by the most recent successful Next(). */
  const TBlobBatch &Value() const override {
    return out_;
  }

 private:
  /*!
   * \brief Open a parallel region and report how many threads it received.
   * \param requested desired team size; a non-positive value lets the OpenMP
   *        runtime choose, because num_threads() requires a strictly positive
   *        argument.
   * \return number of threads in the team.
   */
  static int ProbeTeamSize(int requested) {
    int team_size = 1;
    if (requested > 0) {
      #pragma omp parallel num_threads(requested)
      {
        // exactly one thread writes the shared result
        #pragma omp single
        team_size = omp_get_num_threads();
      }
    } else {
      #pragma omp parallel
      {
        #pragma omp single
        team_size = omp_get_num_threads();
      }
    }
    return team_size;
  }

  /*! \brief Params */
  ThreadedDataLoaderParam param_;
  /*! \brief output */
  TBlobBatch out_;
  /*! \brief batched buffer, owns the arrays referenced by out_ */
  std::vector<NDArray> batched_buffer_;
  /*! \brief pointer to dataset */
  std::shared_ptr<Dataset> dataset_;
  /*! \brief dataset length, read once in Init */
  int64_t dataset_len_{0};
  /*! \brief pointer to sampler iterator, not owned */
  IIterator<DataBatch> *sampler_{nullptr};
  /*! \brief pointer to batchify function */
  BatchifyFunctionPtr batchify_fn_;
  /*! \brief OMPException obj to store and rethrow exceptions from omp blocks*/
  dmlc::OMPException omp_exc_;
};  // class ThreadedDataLoader
// Registers the iterator under the name ThreadedDataLoader. Its argument list
// is the union of the loader's own fields and the prefetcher's fields, and the
// factory wraps a freshly created loader in a PrefetcherIter.
MXNET_REGISTER_IO_ITER(ThreadedDataLoader)
    .describe(R"code(Returns a threaded data loader iterator.
)code" ADD_FILELINE)
    .add_arguments(ThreadedDataLoaderParam::__FIELDS__())
    .add_arguments(PrefetcherParam::__FIELDS__())
    .set_body([]() {
      // the loader is handed to the prefetcher as a raw pointer
      auto* loader = new ThreadedDataLoader<mxnet::real_t>();
      return new PrefetcherIter(loader);
    });
} // namespace io
} // namespace mxnet