blob: d5c9bd2f4578563f8e27c5d4caca6c8c6e7b9b90 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file iter_sparse_batchloader.h
* \brief define a batch adapter to create sparse tblob batch
*/
#ifndef MXNET_IO_ITER_SPARSE_BATCHLOADER_H_
#define MXNET_IO_ITER_SPARSE_BATCHLOADER_H_
#include <mxnet/io.h>
#include <mxnet/base.h>
#include <dmlc/logging.h>
#include <mshadow/tensor.h>
#include <utility>
#include <vector>
#include <string>
#include "./inst_vector.h"
#include "./image_iter_common.h"
#include "./iter_batchloader.h"
#include "./iter_sparse.h"
namespace mxnet {
namespace io {
/*! \brief create a batch iterator from single instance iterator */
class SparseBatchLoader : public BatchLoader, public SparseIIterator<TBlobBatch> {
public:
explicit SparseBatchLoader(SparseIIterator<DataInst> *base):
BatchLoader(base), sparse_base_(base) {
}
virtual ~SparseBatchLoader(void) {}
inline void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
BatchLoader::Init(kwargs);
data_stype_ = sparse_base_->GetStorageType(true);
label_stype_ = sparse_base_->GetStorageType(false);
if (param_.round_batch == 0) {
LOG(FATAL) << "sparse batch loader doesn't support round_batch == false yet";
}
}
virtual void BeforeFirst(void) {
BatchLoader::BeforeFirst();
}
virtual bool Next(void) {
out_.num_batch_padd = 0;
out_.batch_size = param_.batch_size;
this->head_ = 0;
// if overflown from previous round, directly return false, until before first is called
if (num_overflow_ != 0) return false;
index_t top = 0;
inst_cache_.clear();
while (sparse_base_->Next()) {
inst_cache_.emplace_back(sparse_base_->Value());
if (inst_cache_.size() >= param_.batch_size) break;
}
// no more data instance
if (inst_cache_.size() == 0) {
return false;
}
if (inst_cache_.size() < param_.batch_size) {
CHECK_GT(param_.round_batch, 0);
num_overflow_ = 0;
sparse_base_->BeforeFirst();
for (; inst_cache_.size() < param_.batch_size; ++num_overflow_) {
CHECK(sparse_base_->Next()) << "number of input must be bigger than batch size";
inst_cache_.emplace_back(sparse_base_->Value());
}
}
out_.num_batch_padd = num_overflow_;
CHECK_EQ(inst_cache_.size(), param_.batch_size);
this->InitDataFromBatch();
for (size_t j = 0; j < inst_cache_.size(); j++) {
const auto& d = inst_cache_[j];
out_.inst_index[top] = d.index;
// TODO(haibin) double check the type?
int64_t unit_size = 0;
for (size_t i = 0; i < d.data.size(); ++i) {
// indptr tensor
if (IsIndPtr(i)) {
auto indptr = data_[i].get<cpu, 1, int64_t>();
if (j == 0) indptr[0] = 0;
indptr[j + 1] = indptr[j] + unit_size;
offsets_[i] = j;
} else {
// indices and values tensor
unit_size = d.data[i].shape_.Size();
MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, {
const auto begin = offsets_[i];
const auto end = offsets_[i] + unit_size;
mshadow::Copy(data_[i].get<cpu, 1, DType>().Slice(begin, end),
d.data[i].get_with_shape<cpu, 1, DType>(mshadow::Shape1(unit_size)));
});
offsets_[i] += unit_size;
}
}
}
return true;
}
virtual const TBlobBatch &Value(void) const {
return BatchLoader::Value();
}
virtual const NDArrayStorageType GetStorageType(bool is_data) const {
return sparse_base_->GetStorageType(is_data);
}
virtual const TShape GetShape(bool is_data) const {
TShape inst_shape = sparse_base_->GetShape(is_data);
std::vector<index_t> shape_vec;
shape_vec.push_back(param_.batch_size);
for (index_t dim = 0; dim < inst_shape.ndim(); ++dim) {
shape_vec.push_back(inst_shape[dim]);
}
return TShape(shape_vec.begin(), shape_vec.end());
}
private:
/*! \brief base sparse iterator */
SparseIIterator<DataInst> *sparse_base_;
/*! \brief data instances */
std::vector<DataInst> inst_cache_;
/*! \brief data storage type */
NDArrayStorageType data_stype_;
/*! \brief data label type */
NDArrayStorageType label_stype_;
/*! \brief tensor offset for slicing */
std::vector<size_t> offsets_;
// check whether ith position is the indptr tensor for a CSR tensor
inline bool IsIndPtr(size_t i) {
auto data_num_aux = num_aux_data(data_stype_);
auto label_num_aux = num_aux_data(label_stype_);
auto label_indptr_offset = data_num_aux + 1 + label_num_aux;
// data indptr
if (i == data_num_aux && data_stype_ == kCSRStorage) {
return true;
}
// label indptr
if (i == label_indptr_offset && label_stype_ == kCSRStorage && data_stype_ == kCSRStorage) {
return true;
}
return false;
}
// initialize the data holder by using from the batch
inline void InitDataFromBatch() {
CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage);
CHECK_GT(inst_cache_.size(), 0);
out_.data.clear();
data_.clear();
offsets_.clear();
size_t total_size = inst_cache_[0].data.size();
data_.resize(total_size);
offsets_.resize(total_size, 0);
std::vector<size_t> vec_sizes(total_size, 0);
// accumulate the memory required for a batch
for (size_t i = 0; i < total_size; ++i) {
size_t size = 0;
// vec_size for indptr
if (IsIndPtr(i)) {
size = param_.batch_size + 1;
} else {
for (const auto &d : inst_cache_) size += d.data[i].shape_.Size();
}
vec_sizes[i] = size;
}
CHECK_EQ(vec_sizes[0], vec_sizes[1]);
for (size_t i = 0; i < total_size; ++i) {
int src_type_flag = inst_cache_[0].data[i].type_flag_;
// init object attributes
TShape dst_shape(mshadow::Shape1(vec_sizes[i]));
data_[i].resize(mshadow::Shape1(vec_sizes[i]), src_type_flag);
CHECK(data_[i].dptr_ != nullptr);
out_.data.push_back(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag));
}
}
}; // class BatchLoader
} // namespace io
} // namespace mxnet
#endif // MXNET_IO_ITER_SPARSE_BATCHLOADER_H_