blob: 153e3c4c056af9eec6cfef56a059d663762f0012 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file dataset.cc
* \brief High performance datasets implementation
*/
#include <dmlc/parameter.h>
#include <dmlc/recordio.h>
#include <dmlc/io.h>
#include <mxnet/io.h>
#include <mxnet/ndarray.h>
#include <mxnet/tensor_blob.h>
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../imperative/cached_op.h"
#include "../imperative/naive_cached_op.h"
#include "../ndarray/ndarray_function.h"
#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
#include "./opencv_compatibility.h"
#endif  // MXNET_USE_OPENCV
namespace mxnet {
namespace io {
/*! \brief Parameters for RecordFileDataset: location of the .rec/.idx file pair. */
struct RecordFileDatasetParam : public dmlc::Parameter<RecordFileDatasetParam> {
  // absolute path to the record (.rec) file holding the packed records
  std::string rec_file;
  // path to the index (.idx) file mapping record keys to byte offsets
  std::string idx_file;
  // declare parameters
  DMLC_DECLARE_PARAMETER(RecordFileDatasetParam) {
    DMLC_DECLARE_FIELD(rec_file).describe("The absolute path of record file.");
    DMLC_DECLARE_FIELD(idx_file).describe("The path of the idx file.");
  }
};  // struct RecordFileDatasetParam
// register the parameter struct with dmlc so fields can be parsed from kwargs
DMLC_REGISTER_PARAMETER(RecordFileDatasetParam);
class RecordFileDataset final : public Dataset {
public:
explicit RecordFileDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
std::vector<std::pair<std::string, std::string>> kwargs_left;
param_.InitAllowUnknown(kwargs);
// read and process idx file
dmlc::Stream* idx_stream = dmlc::Stream::Create(param_.idx_file.c_str(), "r");
dmlc::istream is(idx_stream);
size_t key, idx;
while (is >> key >> idx) {
idx_[key] = idx;
}
delete idx_stream;
}
uint64_t GetLen() const override {
return idx_.size();
}
bool GetItem(uint64_t idx, std::vector<NDArray>* ret) override {
ret->resize(1);
auto& out = (*ret)[0];
static thread_local std::unique_ptr<dmlc::Stream> stream;
static thread_local std::unique_ptr<dmlc::RecordIOReader> reader;
if (!reader) {
auto s = dmlc::Stream::Create(param_.rec_file.c_str(), "r");
stream.reset(s);
reader = std::make_unique<dmlc::RecordIOReader>(s);
}
size_t pos = idx_[static_cast<size_t>(idx)];
reader->Seek(pos);
static thread_local std::string read_buff;
if (reader->NextRecord(&read_buff)) {
const char* buf = read_buff.c_str();
const size_t size = read_buff.size();
out = NDArray(TShape({static_cast<dim_t>(size)}), Context::CPU(), false, mshadow::kInt8);
TBlob dst = out.data();
RunContext rctx{Context::CPU(), nullptr, nullptr};
mxnet::ndarray::Copy<cpu, cpu>(TBlob(const_cast<void*>(reinterpret_cast<const void*>(buf)),
out.shape(),
cpu::kDevMask,
out.dtype(),
0),
&dst,
Context::CPU(),
Context::CPU(),
rctx);
}
return true;
}
private:
/*! \brief parameters */
RecordFileDatasetParam param_;
/*! \brief indices */
std::unordered_map<size_t, size_t> idx_;
};
// Factory registration: exposes RecordFileDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(RecordFileDataset)
    .describe("MXNet Record File Dataset")
    .add_arguments(RecordFileDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new RecordFileDataset(kwargs);
    });
/*! \brief Parameters for ImageRecordFileDataset. */
struct ImageRecordFileDatasetParam : public dmlc::Parameter<ImageRecordFileDatasetParam> {
  // absolute path to the record (.rec) file
  std::string rec_file;
  // path to the index (.idx) file
  std::string idx_file;
  // image decode flag: 1 -> color, 0 -> grayscale (passed to cv::imdecode)
  int flag;
  // declare parameters
  DMLC_DECLARE_PARAMETER(ImageRecordFileDatasetParam) {
    DMLC_DECLARE_FIELD(rec_file).describe("The absolute path of record file.");
    DMLC_DECLARE_FIELD(idx_file).describe("The path of the idx file.");
    DMLC_DECLARE_FIELD(flag).set_default(1).describe(
        "If 1, always convert to colored, if 0 always convert to grayscale.");
  }
};  // struct ImageRecordFileDatasetParam
DMLC_REGISTER_PARAMETER(ImageRecordFileDatasetParam);
#if MXNET_USE_OPENCV
/*!
 * \brief Copy an OpenCV image into an NDArray in HWC uint8 layout, while
 *  reordering channels from OpenCV's BGR(A) order to RGB(A).
 * \tparam n_channels number of channels (1, 3 or 4)
 * \param img decoded OpenCV image
 * \param arr output array; reallocated only when shape/dtype/ctx/storage mismatch
 */
template <int n_channels>
void SwapImageChannels(const cv::Mat& img, NDArray* arr) {
  // permutation table: output channel k is read from input channel perm[k]
  int perm[n_channels];  // NOLINT(*)
  for (int c = 0; c < n_channels; ++c) {
    perm[c] = c;
  }
  if (n_channels >= 3) {
    // swap B and R; alpha channel (if present) stays in place
    perm[0] = 2;
    perm[2] = 0;
  }
  TShape expected_shape = TShape({img.rows, img.cols, n_channels});
  // reuse the output buffer when it already matches; otherwise reallocate
  const bool reusable = !arr->is_none() && arr->shape() == expected_shape &&
                        arr->ctx() == mxnet::Context::CPU(0) &&
                        arr->dtype() == mshadow::kUint8 &&
                        arr->storage_type() == kDefaultStorage;
  if (!reusable) {
    *arr = NDArray(expected_shape, mxnet::Context::CPU(0), false, mshadow::kUint8);
  }
  auto* dst_base = static_cast<uint8_t*>(arr->data().dptr_);
  // copy row by row, permuting channels on the fly
  for (int row = 0; row < img.rows; ++row) {
    const uint8_t* src = img.ptr<uint8_t>(row);
    uint8_t* dst = dst_base + static_cast<size_t>(row) * img.cols * n_channels;
    for (int col = 0; col < img.cols; ++col) {
      for (int k = 0; k < n_channels; ++k) {
        dst[k] = src[perm[k]];
      }
      src += n_channels;
      dst += n_channels;
    }
  }
}
#endif
/*! \brief Struct for unpack recordio header */
#pragma pack(1)
struct IRHeader {
uint32_t flag;
float label;
uint64_t id;
uint64_t id2;
}; // struct IRHeader
class ImageRecordFileDataset : public Dataset {
public:
explicit ImageRecordFileDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
std::vector<std::pair<std::string, std::string>> kwargs_left;
param_.InitAllowUnknown(kwargs);
base_ = std::make_shared<RecordFileDataset>(kwargs);
}
uint64_t GetLen() const override {
return base_->GetLen();
}
bool GetItem(uint64_t idx, std::vector<NDArray>* ret) override {
CHECK_LT(idx, GetLen());
std::vector<NDArray> raw;
if (!base_->GetItem(idx, &raw))
return false;
CHECK_EQ(raw.size(), 1U) << "RecordFileDataset should return size 1 NDArray vector";
uint8_t* s = reinterpret_cast<uint8_t*>(raw[0].data().dptr_);
size_t size = raw[0].shape().Size();
CHECK_GT(size, sizeof(IRHeader)) << "Invalid size of bytes from Record File";
IRHeader header;
std::memcpy(&header, s, sizeof(header));
size -= sizeof(header);
s += sizeof(header);
NDArray label = NDArray(Context::CPU(), mshadow::default_type_flag);
RunContext rctx{Context::CPU(), nullptr, nullptr};
if (header.flag > 0) {
auto label_shape = header.flag <= 1 ? TShape(0, 1) : TShape({header.flag});
label.ReshapeAndAlloc(label_shape);
TBlob dst = label.data();
mxnet::ndarray::Copy<cpu, cpu>(
TBlob(reinterpret_cast<void*>(s), label.shape(), cpu::kDevMask, label.dtype(), 0),
&dst,
Context::CPU(),
Context::CPU(),
rctx);
s += sizeof(float) * header.flag;
size -= sizeof(float) * header.flag;
} else {
// label is a scalar with ndim() == 0
label.ReshapeAndAlloc(TShape(0, 1));
TBlob dst = label.data();
*(dst.dptr<float>()) = header.label;
}
ret->resize(2);
(*ret)[1] = label;
#if MXNET_USE_OPENCV
cv::Mat buf(1, size, CV_8U, s);
cv::Mat res = cv::imdecode(buf, param_.flag);
CHECK(!res.empty()) << "Decoding failed. Invalid image file.";
const int n_channels = res.channels();
if (n_channels == 1) {
SwapImageChannels<1>(res, &(ret->at(0)));
} else if (n_channels == 3) {
SwapImageChannels<3>(res, &(ret->at(0)));
} else if (n_channels == 4) {
SwapImageChannels<4>(res, &(ret->at(0)));
}
return true;
#else
LOG(FATAL) << "Opencv is needed for image decoding.";
#endif
return false; // should not reach here
}
private:
/*! \brief parameters */
ImageRecordFileDatasetParam param_;
/*! \brief base recordIO reader */
std::shared_ptr<RecordFileDataset> base_;
};
// Factory registration: exposes ImageRecordFileDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(ImageRecordFileDataset)
    .describe("MXNet Image Record File Dataset")
    .add_arguments(ImageRecordFileDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new ImageRecordFileDataset(kwargs);
    });
/*! \brief Parameters for ImageSequenceDataset. */
struct ImageSequenceDatasetParam : public dmlc::Parameter<ImageSequenceDatasetParam> {
  /*! \brief the list of absolute image paths, joined with `path_sep` */
  std::string img_list;
  /*! \brief the path separator character, by default it's '|' */
  char path_sep;
  /*! \brief If flag is 0, always convert to grayscale(1 channel).
   * If flag is 1, always convert to colored (3 channels).
   * If flag is -1, keep channels unchanged.
   */
  int flag;
  // declare parameters
  DMLC_DECLARE_PARAMETER(ImageSequenceDatasetParam) {
    DMLC_DECLARE_FIELD(img_list).describe("The list of image absolute paths.");
    DMLC_DECLARE_FIELD(path_sep).set_default('|').describe(
        "The path separator for joined image paths.");
    DMLC_DECLARE_FIELD(flag).set_default(1).describe(
        "If 1, always convert to colored, if 0 always convert to grayscale.");
  }
};  // struct ImageSequenceDatasetParam
DMLC_REGISTER_PARAMETER(ImageSequenceDatasetParam);
class ImageSequenceDataset final : public Dataset {
 public:
  /*!
   * \brief Construct from kwargs; splits `img_list` into individual image
   *  paths using `path_sep`.
   */
  explicit ImageSequenceDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
    param_.InitAllowUnknown(kwargs);
    img_list_ = dmlc::Split(param_.img_list, param_.path_sep);
  }

  uint64_t GetLen() const override {
    return img_list_.size();
  }

  /*!
   * \brief Load and decode the image at `idx` into ret[0] as an HWC uint8
   *  NDArray in RGB(A) channel order.
   */
  bool GetItem(uint64_t idx, std::vector<NDArray>* ret) override {
#if MXNET_USE_OPENCV
    CHECK_LT(idx, img_list_.size())
        << "GetItem index: " << idx << " out of bound: " << img_list_.size();
    cv::Mat res = cv::imread(img_list_[idx], param_.flag);
    CHECK(!res.empty()) << "Decoding failed. Invalid image file.";
    const int n_channels = res.channels();
    ret->resize(1);
    if (n_channels == 1) {
      SwapImageChannels<1>(res, &(ret->at(0)));
    } else if (n_channels == 3) {
      SwapImageChannels<3>(res, &(ret->at(0)));
    } else if (n_channels == 4) {
      SwapImageChannels<4>(res, &(ret->at(0)));
    } else {
      // previously an unsupported channel count silently left ret[0] empty
      LOG(FATAL) << "Unsupported number of channels: " << n_channels;
    }
    return true;
#else
    LOG(FATAL) << "Opencv is needed for image decoding.";
#endif
    return false;
  }

 private:
  /*! \brief parameters */
  ImageSequenceDatasetParam param_;
  /*! \brief image list */
  std::vector<std::string> img_list_;
};
// Factory registration: exposes ImageSequenceDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(ImageSequenceDataset)
    .describe("Image Sequence Dataset")
    .add_arguments(ImageSequenceDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new ImageSequenceDataset(kwargs);
    });
/*! \brief Parameters for NDArrayDataset. */
struct NDArrayDatasetParam : public dmlc::Parameter<NDArrayDatasetParam> {
  /*! \brief the source ndarray, passed as a raw pointer value from the frontend */
  std::intptr_t arr;
  // declare parameters
  DMLC_DECLARE_PARAMETER(NDArrayDatasetParam) {
    DMLC_DECLARE_FIELD(arr).describe("Pointer to NDArray.");
  }
};  // struct NDArrayDatasetParam
DMLC_REGISTER_PARAMETER(NDArrayDatasetParam);
class NDArrayDataset final : public Dataset {
public:
explicit NDArrayDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
param_.InitAllowUnknown(kwargs);
data_ = *(static_cast<NDArray*>(reinterpret_cast<void*>(param_.arr)));
if (data_.shape().ndim() < 1) {
LOG(FATAL) << "NDArray with no dim is not iterable";
}
size_ = data_.shape().begin()[0];
}
uint64_t GetLen() const override {
return size_;
}
bool GetItem(uint64_t idx, std::vector<NDArray>* rets) override {
CHECK_LT(idx, size_) << "GetItem index: " << idx << " out of bound: " << size_;
rets->resize(1);
auto& ret = (*rets)[0];
ret = data_.Slice(idx, idx + 1);
if (ret.shape().ndim() > 1) {
// remove first dim to be consistent with numpy
TShape new_shape;
new_shape.assign(ret.shape().begin() + 1, ret.shape().end());
ret = ret.Reshape(new_shape);
} else {
if (data_.shape().ndim() == 1) {
// scalar
TShape new_shape(0, 1);
ret = ret.Reshape(new_shape);
}
}
return true;
}
private:
/*! \brief parameters */
NDArrayDatasetParam param_;
/*! \brief stored ndarray */
NDArray data_;
/*! \brief stored ndarray shape */
int64_t size_;
}; // class NDArrayDataset
// Factory registration: exposes NDArrayDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(NDArrayDataset)
    .describe("Single NDArray Dataset")
    .add_arguments(NDArrayDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new NDArrayDataset(kwargs);
    });
/*! \brief Parameters for GroupDataset. */
struct GroupDatasetParam : public dmlc::Parameter<GroupDatasetParam> {
  /*! \brief pointers to the child c++ datasets (not ndarrays) */
  Tuple<std::intptr_t> datasets;
  // declare parameters
  DMLC_DECLARE_PARAMETER(GroupDatasetParam) {
    DMLC_DECLARE_FIELD(datasets).describe("A small set of pointers to other c++ datasets.");
  }
};  // struct GroupDatasetParam
DMLC_REGISTER_PARAMETER(GroupDatasetParam);
class GroupDataset final : public Dataset {
public:
explicit GroupDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
std::vector<std::pair<std::string, std::string>> kwargs_left;
param_.InitAllowUnknown(kwargs);
auto childs = param_.datasets;
childs_.reserve(childs.ndim());
size_t child_cnt = 0;
for (auto child : childs) {
auto d = *static_cast<std::shared_ptr<Dataset>*>(reinterpret_cast<void*>(child));
if (child_cnt == 0) {
size_ = d->GetLen();
} else {
CHECK_EQ(size_, d->GetLen()) << "All child dataset of GroupDataset must be identical "
<< "Given mismatch: " << size_ << " vs " << d->GetLen();
}
childs_.emplace_back(d);
child_cnt++;
}
}
uint64_t GetLen() const override {
return size_;
}
bool GetItem(uint64_t idx, std::vector<NDArray>* ret) override {
CHECK_LT(idx, size_) << "GetItem index: " << idx << " out of bound: " << size_;
ret->clear();
for (const auto& child : childs_) {
std::vector<NDArray> temp_ret;
if (!child->GetItem(idx, &temp_ret))
return false;
ret->insert(ret->end(), temp_ret.begin(), temp_ret.end());
}
return true;
}
private:
/*! \brief parameters */
GroupDatasetParam param_;
/*! \brief stored child datasets */
std::vector<std::shared_ptr<Dataset>> childs_;
/*! \brief overall dataset size, equals to all child datasets */
uint64_t size_;
}; // class GroupDataset
// Factory registration: exposes GroupDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(GroupDataset)
    .describe("Grouped Dataset that combine a bunch of datasets")
    .add_arguments(GroupDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new GroupDataset(kwargs);
    });
/*! \brief Parameters for IndexedDataset. */
struct IndexedDatasetParam : public dmlc::Parameter<IndexedDatasetParam> {
  /*! \brief the base dataset, passed as a raw pointer value from the frontend */
  std::intptr_t base;
  /*! \brief the indices into the base dataset */
  Tuple<uint64_t> indices;
  // declare parameters
  DMLC_DECLARE_PARAMETER(IndexedDatasetParam) {
    DMLC_DECLARE_FIELD(base).describe(
        "Pointer to the internal c++ dataset that is going to be indexed.");
    DMLC_DECLARE_FIELD(indices).describe(
        "The indices for the internal dataset. Output[i] will be base[indices[i]].");
  }
};  // struct IndexedDatasetParam
DMLC_REGISTER_PARAMETER(IndexedDatasetParam);
class IndexedDataset final : public Dataset {
 public:
  /*! \brief Construct from kwargs: resolve the base dataset pointer. */
  explicit IndexedDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
    param_.InitAllowUnknown(kwargs);
    void* raw = reinterpret_cast<void*>(param_.base);
    base_data_ = *static_cast<std::shared_ptr<Dataset>*>(raw);
  }

  uint64_t GetLen() const override {
    return param_.indices.ndim();
  }

  /*! \brief Fetch base_data_[indices[idx]] after bounds-checking both levels. */
  bool GetItem(uint64_t idx, std::vector<NDArray>* ret) override {
    CHECK_GT(param_.indices.ndim(), idx)
        << "IndexError: " << idx << " from total: " << param_.indices.ndim();
    const auto new_idx = param_.indices[idx];
    CHECK_GT(base_data_->GetLen(), new_idx)
        << "IndexError: " << new_idx
        << " from original dataset with size: " << base_data_->GetLen();
    // delegate the actual fetch to the wrapped dataset
    return base_data_->GetItem(new_idx, ret);
  }

 private:
  /*! \brief parameters */
  IndexedDatasetParam param_;
  /*! \brief the dataset being indexed into */
  std::shared_ptr<Dataset> base_data_;
};  // class IndexedDataset
// Factory registration: exposes IndexedDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(IndexedDataset)
    // Fixed description: the previous text was copy-pasted from GroupDataset.
    .describe("Dataset that fetches items from a base dataset through an index mapping")
    .add_arguments(IndexedDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new IndexedDataset(kwargs);
    });
/*! \brief Parameters for LazyTransformDataset. */
struct LazyTransformDatasetParam : public dmlc::Parameter<LazyTransformDatasetParam> {
  /*! \brief pointer to the cached op implementing the transform (not an ndarray) */
  std::intptr_t cached_op;
  /*! \brief pointer to the internal base dataset */
  std::intptr_t dataset;
  /*! \brief indices for items that needs transformation */
  Tuple<int> transform_indices;
  /*! \brief is_scalar information for outputs */
  Tuple<int> scalar_outputs;
  // declare parameters
  DMLC_DECLARE_PARAMETER(LazyTransformDatasetParam) {
    DMLC_DECLARE_FIELD(cached_op).describe("Pointer to cached transform function.");
    DMLC_DECLARE_FIELD(dataset).describe("Pointer to internal dataset.");
    DMLC_DECLARE_FIELD(transform_indices)
        .set_default(Tuple<int>({}))
        .describe(
            "The indices for dataset items that need to be transformed/processed. "
            "If `transform_indices` is empty(default), "
            "then all items will be processed.");
    DMLC_DECLARE_FIELD(scalar_outputs)
        .describe("Indicate whether outputs are scalars, the size must match the output size.");
  }
};  // struct LazyTransformDatasetParam
DMLC_REGISTER_PARAMETER(LazyTransformDatasetParam);
class LazyTransformDataset final : public Dataset {
 public:
  /*!
   * \brief Copy constructor: shares the base dataset and index bookkeeping
   *  with `other`, but builds a fresh NaiveCachedOp from the same symbol and
   *  flags so each copy owns its own op instance.
   *  NOTE(review): presumably this lets per-worker copies run the transform
   *  independently — confirm with call sites.
   */
  LazyTransformDataset(const LazyTransformDataset& other) {
    this->param_ = other.param_;
    this->pass_through_indices_ = other.pass_through_indices_;
    this->use_input_indices_ = other.use_input_indices_;
    this->num_outputs_ = other.num_outputs_;
    this->cached_op_ =
        std::make_shared<NaiveCachedOp>(other.cached_op_->sym_, other.cached_op_->flags_);
    this->base_data_ = other.base_data_;
  }

  /*!
   * \brief Construct from kwargs: resolves the cached op and base dataset
   *  pointers, then fetches item 0 of the base dataset once to work out which
   *  inputs are transformed vs. passed through, and the total output count.
   */
  explicit LazyTransformDataset(const std::vector<std::pair<std::string, std::string>>& kwargs) {
    param_.InitAllowUnknown(kwargs);
    auto op = *static_cast<CachedOpPtr*>(reinterpret_cast<void*>(param_.cached_op));
    // re-wrap the symbol/flags in a NaiveCachedOp owned by this dataset
    cached_op_ = std::make_shared<NaiveCachedOp>(op->sym_, op->flags_);
    base_data_ = *static_cast<std::shared_ptr<Dataset>*>(reinterpret_cast<void*>(param_.dataset));
    // use first item to calculate size info
    CHECK_GT(GetLen(), 0) << "LazyTransformDataset expect the base dataset to have at least 1 item";
    std::vector<NDArray> inputs;
    CHECK(base_data_->GetItem(0, &inputs));
    // check output size
    CHECK_EQ(param_.scalar_outputs.ndim(), cached_op_->num_outputs())
        << "Output scalar info size: " << param_.scalar_outputs.ndim()
        << " vs. output size: " << cached_op_->num_outputs() << " mismatch!";
    // check input size: empty transform_indices means "transform the first
    // num_inputs items in order"
    if (param_.transform_indices.ndim() == 0) {
      std::vector<int> default_indices;
      default_indices.reserve(cached_op_->num_inputs());
      for (size_t i = 0; i < cached_op_->num_inputs(); ++i) {
        default_indices.emplace_back(static_cast<int>(i));
      }
      use_input_indices_ = default_indices;
    } else {
      use_input_indices_ =
          std::vector<int>(param_.transform_indices.begin(), param_.transform_indices.end());
    }
    CHECK_EQ(use_input_indices_.size(), cached_op_->num_inputs())
        << "Mismatched transform indices and transform inputs: " << use_input_indices_.size()
        << " vs. " << cached_op_->num_inputs();
    auto num_inputs = use_input_indices_.size();
    CHECK_GE(inputs.size(), num_inputs) << "LazyTransformDataset input size " << inputs.size()
                                        << " smaller than transform input size: " << num_inputs;
    // any base-dataset item not consumed by the op is passed through unchanged
    pass_through_indices_.clear();
    for (size_t i = 0; i < inputs.size(); ++i) {
      // filling output ndarray from unaltered inputs, transformed outputs are already inserted
      if (std::find(use_input_indices_.begin(), use_input_indices_.end(), i) ==
          use_input_indices_.end()) {
        pass_through_indices_.emplace_back(i);
      }
    }
    num_outputs_ = inputs.size() + cached_op_->num_outputs() - cached_op_->num_inputs();
  }

  ~LazyTransformDataset() override = default;

  uint64_t GetLen() const override {
    return base_data_->GetLen();
  }

  /*!
   * \brief Fetch item `idx` from the base dataset and run the cached-op
   *  transform on the selected inputs.
   *  Output layout: [op outputs..., pass-through inputs...].
   * \return false if the base dataset fails to produce the item.
   */
  bool GetItem(uint64_t idx, std::vector<NDArray>* outputs) override {
    std::vector<NDArray> inputs;
    if (!base_data_->GetItem(idx, &inputs))
      return false;
    // op outputs occupy the first num_outputs() slots; pass-through inputs
    // are appended after them
    outputs->reserve(num_outputs_);
    outputs->resize(cached_op_->num_outputs());
    for (auto i : pass_through_indices_) {
      outputs->emplace_back(inputs[i]);
    }
    CHECK_EQ(outputs->size(), num_outputs_);
    // workspace for cached op
    std::vector<NDArray*> ndinputs;
    std::vector<NDArray*> ndoutputs;
    ndinputs.reserve(inputs.size());
    for (int use_input_indice : use_input_indices_) {
      ndinputs.emplace_back(&(inputs[use_input_indice]));
    }
    ndoutputs.reserve(cached_op_->num_outputs());
    CHECK_LE(cached_op_->num_outputs(), outputs->size());
    for (size_t i = 0; i < cached_op_->num_outputs(); ++i) {
      ndoutputs.emplace_back(&(outputs->at(i)));
    }
    // block until all inputs are ready before running the op
    for (auto& input : inputs) {
      input.WaitToRead();
    }
    CHECK(inputs.size() > 0) << "dataset getitem requires at least one input";
    Context default_ctx = inputs[0].ctx();
    cached_op_->Forward(cached_op_, ndinputs, ndoutputs, default_ctx);
    return true;
  }

 private:
  /*! \brief parameters */
  LazyTransformDatasetParam param_;
  /*! \brief stored cached op */
  NaiveCachedOpPtr cached_op_;
  /*! \brief internal dataset */
  std::shared_ptr<Dataset> base_data_;
  /*! \brief indices of base-dataset items fed into the cached op */
  std::vector<int> use_input_indices_;
  /*! \brief indices of base-dataset items copied through untransformed */
  std::vector<int> pass_through_indices_;
  /*! \brief total items per output: op outputs + pass-through inputs */
  size_t num_outputs_;
};  // class LazyTransformDataset
// Factory registration: exposes LazyTransformDataset to the frontend by name.
MXNET_REGISTER_IO_DATASET(LazyTransformDataset)
    .describe("Dataset that apply lazy transformation to internal dataset")
    .add_arguments(LazyTransformDatasetParam::__FIELDS__())
    .set_body([](const std::vector<std::pair<std::string, std::string>>& kwargs) {
      return new LazyTransformDataset(kwargs);
    });
} // namespace io
} // namespace mxnet