blob: a07089fecc7dd741572e8679e5116e7108f23b9d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file iter_sampler.cc
* \brief Sampler iterators for accessing dataset elements.
*/
#include <dmlc/parameter.h>
#include <mshadow/random.h>
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <mxnet/resource.h>
#include <memory>
#include <numeric>
#include "../common/utils.h"
#include "./iter_batchloader.h"
#include "./iter_prefetcher.h"
namespace mxnet {
namespace io {
/*!
* \brief Parameters accepted by SequentialSampler.
*
* Declared through the dmlc parameter macros: `length` has no default
* declared, while `start` defaults to 0.
*/
struct SequentialSamplerParam : public dmlc::Parameter<SequentialSamplerParam> {
/*! \brief Number of indices in the sequence. */
size_t length;
/*! \brief Index at which the sequence is meant to begin (defaults to 0). */
int start;
// Declare the fields, their defaults and their user-facing descriptions.
DMLC_DECLARE_PARAMETER(SequentialSamplerParam) {
DMLC_DECLARE_FIELD(length).describe("Length of the sequence.");
DMLC_DECLARE_FIELD(start).set_default(0).describe("Start of the index.");
}
}; // struct SequentialSamplerParam
// Register the parameter struct with dmlc.
DMLC_REGISTER_PARAMETER(SequentialSamplerParam);
/*!
 * \brief Iterator that yields consecutive integer indices, one per step.
 *
 * After Init() the sampler stores `length` indices beginning at `start`,
 * i.e. start, start + 1, ..., start + length - 1. Every successful Next()
 * publishes one index through Value() as a TBlob of shape {1} that points
 * directly into the internal index buffer (no copy is made).
 */
class SequentialSampler : public IIterator<DataInst> {
 public:
  /*!
   * \brief Parse the arguments and materialize the index sequence.
   * \param kwargs key/value argument pairs, parsed with InitAllowUnknown
   *        (presumably keys that are not fields of SequentialSamplerParam
   *        are tolerated - confirm against dmlc::Parameter).
   */
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.InitAllowUnknown(kwargs);
    indices_.resize(param_.length);
    // Begin counting at the declared `start` parameter rather than a literal
    // 0; with the default start of 0 the sequence is 0, 1, ..., length - 1.
    std::iota(std::begin(indices_), std::end(indices_), static_cast<int64_t>(param_.start));
    out_.data.resize(1);
    // A freshly initialized sampler starts at the first index.
    pos_ = 0;
  }

  /*! \brief Rewind the iterator to the first index. */
  void BeforeFirst() override {
    pos_ = 0;
  }

  /*! \return number of indices held by the sampler. */
  int64_t GetLenHint() const override {
    return static_cast<int64_t>(indices_.size());
  }

  /*!
   * \brief Advance to the next index.
   * \return true when an index was published to Value(); false once every
   *         index has been consumed.
   */
  bool Next() override {
    if (pos_ >= indices_.size()) {
      return false;
    }
    // Expose the current element in place: the blob aliases indices_.
    int64_t* current = indices_.data() + pos_;
    out_.data[0] = TBlob(current, TShape({1}), cpu::kDevMask, 0);
    ++pos_;
    return true;
  }

  /*!
   * \brief Access the element published by the last successful Next().
   * \return reference to the internal DataInst; its blob points into the
   *         sampler's own index buffer.
   */
  const DataInst& Value() const override {
    return out_;
  }

 private:
  /*! \brief Stored integer indices */
  std::vector<int64_t> indices_;
  /*! \brief current position for iteration (zero until advanced by Next()) */
  std::size_t pos_{0};
  /*! \brief data for next value */
  DataInst out_;
  /*! \brief arguments */
  SequentialSamplerParam param_;
};  // class SequentialSampler
// Register an IO iterator entry for SequentialSampler. The advertised
// argument list is the union of the sampler's own fields and the
// BatchSamplerParam fields.
MXNET_REGISTER_IO_ITER(SequentialSampler)
.describe(R"code(Returns the sequential sampler iterator.
)code" ADD_FILELINE)
.add_arguments(SequentialSamplerParam::__FIELDS__())
.add_arguments(BatchSamplerParam::__FIELDS__())
// The factory hands back a BatchSampler wrapping a new SequentialSampler.
// NOTE(review): presumably BatchSampler takes ownership of the raw pointer -
// confirm in iter_batchloader.h / the BatchSampler definition.
.set_body([]() { return new BatchSampler(new SequentialSampler()); });
/*!
* \brief Parameters accepted by RandomSampler.
*
* Declared through the dmlc parameter macros; the single field `length` has
* no default declared.
*/
struct RandomSamplerParam : public dmlc::Parameter<RandomSamplerParam> {
/*! \brief Number of indices in the sequence. */
size_t length;
// Declare the field and its user-facing description.
DMLC_DECLARE_PARAMETER(RandomSamplerParam) {
DMLC_DECLARE_FIELD(length).describe("Length of the sequence.");
}
}; // struct RandomSamplerParam
// Register the parameter struct with dmlc.
DMLC_REGISTER_PARAMETER(RandomSamplerParam);
/*!
 * \brief Iterator that yields the indices 0 .. length - 1 in shuffled order.
 *
 * The index buffer is reshuffled on every BeforeFirst() call (Init() performs
 * the first shuffle), so each pass visits the indices in a new order. Every
 * successful Next() publishes one index through Value() as a TBlob of shape
 * {1} that points directly into the internal index buffer (no copy is made).
 */
class RandomSampler : public IIterator<DataInst> {
 public:
  /*!
   * \brief Parse the arguments, build the index buffer, seed the engine and
   *        perform the first shuffle.
   * \param kwargs key/value argument pairs, parsed with InitAllowUnknown
   *        (presumably keys that are not fields of RandomSamplerParam are
   *        tolerated - confirm against dmlc::Parameter).
   */
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.InitAllowUnknown(kwargs);
    indices_.resize(param_.length);
    std::iota(std::begin(indices_), std::end(indices_), 0);  // fill like arange
    // Seed the local engine from the seed of the CPU random resource.
    mshadow::Random<cpu>* ctx_rng = ResourceManager::Get()
                                        ->Request(Context::CPU(), ResourceRequest::kRandom)
                                        .get_random<cpu, real_t>(nullptr);
    rng_ = std::make_unique<common::RANDOM_ENGINE>(ctx_rng->GetSeed());
    out_.data.resize(1);
    BeforeFirst();
  }

  /*! \brief Reshuffle the indices and rewind to the first one. */
  void BeforeFirst() override {
    // rng_ only exists after Init(); before that there is nothing to shuffle,
    // so skip the shuffle instead of dereferencing a null pointer.
    if (rng_ != nullptr) {
      std::shuffle(std::begin(indices_), std::end(indices_), *rng_);
    }
    pos_ = 0;
  }

  /*! \return number of indices held by the sampler. */
  int64_t GetLenHint() const override {
    return static_cast<int64_t>(indices_.size());
  }

  /*!
   * \brief Advance to the next shuffled index.
   * \return true when an index was published to Value(); false once every
   *         index has been consumed.
   */
  bool Next() override {
    if (pos_ >= indices_.size()) {
      return false;
    }
    // Expose the current element in place: the blob aliases indices_.
    int64_t* current = indices_.data() + pos_;
    out_.data[0] = TBlob(current, TShape({1}), cpu::kDevMask, 0);
    ++pos_;
    return true;
  }

  /*!
   * \brief Access the element published by the last successful Next().
   * \return reference to the internal DataInst; its blob points into the
   *         sampler's own index buffer, which BeforeFirst() reshuffles.
   */
  const DataInst& Value() const override {
    return out_;
  }

 private:
  /*! \brief Stored integer indices */
  std::vector<int64_t> indices_;
  /*! \brief current position for iteration (zero until advanced by Next()) */
  std::size_t pos_{0};
  /*! \brief data for next value */
  DataInst out_;
  /*! \brief random generator engine; created in Init(), same type as constructed there */
  std::unique_ptr<common::RANDOM_ENGINE> rng_;
  /*! \brief arguments */
  RandomSamplerParam param_;
};  // class RandomSampler
// Register an IO iterator entry for RandomSampler. The advertised argument
// list is the union of the sampler's own fields and the BatchSamplerParam
// fields.
MXNET_REGISTER_IO_ITER(RandomSampler)
.describe(R"code(Returns the random sampler iterator.
)code" ADD_FILELINE)
.add_arguments(RandomSamplerParam::__FIELDS__())
.add_arguments(BatchSamplerParam::__FIELDS__())
// The factory hands back a BatchSampler wrapping a new RandomSampler.
// NOTE(review): presumably BatchSampler takes ownership of the raw pointer -
// confirm in iter_batchloader.h / the BatchSampler definition.
.set_body([]() { return new BatchSampler(new RandomSampler()); });
} // namespace io
} // namespace mxnet