// blob: 6a6b03f9c2fb4317ed4a08de7ca504ce9f5b16e1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file iter_sframe_image.cc
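* \brief SFrame-backed data iterators (encoded-image and plain vector columns) registered as MXNet IO iterators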
* \author Bing Xu
*/
#include <mxnet/io.h>
#include <dmlc/base.h>
#include <dmlc/io.h>
#include <dmlc/omp.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <string>
#include <memory>
#include <unity/lib/image_util.hpp>
#include <unity/lib/gl_sframe.hpp>
#include <unity/lib/gl_sarray.hpp>
#include "../../src/io/inst_vector.h"
#include "../../src/io/image_recordio.h"
#include "../../src/io/image_augmenter.h"
#include "../../src/io/iter_prefetcher.h"
#include "../../src/io/iter_normalize.h"
#include "../../src/io/iter_batchloader.h"
namespace mxnet {
namespace io {
/*!
* \brief Parameters shared by the SFrame iterators: where the sframe lives,
* which two columns to read, and the per-instance data/label shapes.
*/
struct SFrameParam : public dmlc::Parameter<SFrameParam> {
/*! \brief sframe path */
std::string path_sframe;
/*! \brief name of the sframe column that holds the data */
std::string data_field;
/*! \brief name of the sframe column that holds the label */
std::string label_field;
/*! \brief shape of a single data instance */
mxnet::TShape data_shape;
/*! \brief shape of a single label instance */
mxnet::TShape label_shape;
// Field declarations: the three string fields carry defaults ("", "data",
// "label"); data_shape and label_shape are declared without a default.
DMLC_DECLARE_PARAMETER(SFrameParam) {
DMLC_DECLARE_FIELD(path_sframe).set_default("")
.describe("Dataset Param: path to image dataset sframe");
DMLC_DECLARE_FIELD(data_field).set_default("data")
.describe("Dataset Param: data column in sframe");
DMLC_DECLARE_FIELD(label_field).set_default("label")
.describe("Dataset Param: label column in sframe");
DMLC_DECLARE_FIELD(data_shape)
.describe("Dataset Param: input data instance shape");
DMLC_DECLARE_FIELD(label_shape)
.describe("Dataset Param: input label instance shape");
}
}; // struct SFrameParam
/*!
 * \brief Common base of the SFrame instance iterators.
 *
 * Opens the sframe named by SFrameParam::path_sframe, keeps only the data and
 * label columns (in that order, so a row's element 0 is the data and element 1
 * is the label) and owns the row cursor. Subclasses implement Next() to turn
 * the current row into out_.
 */
class SFrameIterBase : public IIterator<DataInst> {
 public:
  SFrameIterBase() {}
  /*!
   * \brief parse the parameters, open the sframe and rewind to its first row
   * \param kwargs key/value settings; parsed with InitAllowUnknown so that
   *        settings meant for other iterator layers can share the same list
   */
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.InitAllowUnknown(kwargs);
    sframe_ = graphlab::gl_sframe(param_.path_sframe)[{param_.data_field, param_.label_field}];
    range_it_.reset(new graphlab::gl_sframe_range(sframe_.range_iterator()));
    this->BeforeFirst();
  }
  virtual ~SFrameIterBase() {}
  /*! \brief rewind: reset the instance counter and restart the row cursor */
  void BeforeFirst() override {
    // range_it_ is only allocated by Init(); fail loudly instead of
    // dereferencing a null pointer when the call order is wrong.
    CHECK_EQ(range_it_ != nullptr, true)
      << "SFrameIterBase::Init must be called before BeforeFirst";
    idx_ = 0;
    *range_it_ = sframe_.range_iterator();
    current_it_ = range_it_->begin();
  }
  /*! \brief \return the instance stored in out_ by the subclass' Next() */
  const DataInst &Value(void) const override {
    return out_;
  }
  /*! \brief advance to the next row; implemented by the subclasses */
  virtual bool Next() = 0;

 protected:
  /*! \brief index of instance; reset to 0 by BeforeFirst() */
  index_t idx_ = 0;
  /*! \brief output of sframe iterator */
  DataInst out_;
  /*! \brief temp space */
  InstVector tmp_;
  /*! \brief sframe iter parameter */
  SFrameParam param_;
  /*! \brief sframe object */
  graphlab::gl_sframe sframe_;
  /*! \brief sframe range iterator */
  std::unique_ptr<graphlab::gl_sframe_range> range_it_;
  /*! \brief current iterator in range iterator */
  graphlab::gl_sframe_range::iterator current_it_;

  /*!
   * \brief copy a flex_vec element-wise (converted to float) into a tensor
   * \param tensor destination; must be contiguous and hold exactly
   *        vec.size() elements, both are enforced with CHECKs
   * \param vec source values
   * \tparam dim number of dimensions of the destination tensor
   */
  template<int dim>
  void Copy_(mshadow::Tensor<cpu, dim> tensor, const graphlab::flex_vec &vec) {
    CHECK_EQ(tensor.shape_.Size(), vec.size());
    CHECK_EQ(tensor.CheckContiguous(), true);
    // View the destination as 1-D so one flat loop covers any rank.
    mshadow::Tensor<cpu, 1> flatten(tensor.dptr_, mshadow::Shape1(tensor.shape_.Size()));
    for (index_t i = 0; i < vec.size(); ++i) {
      flatten[i] = static_cast<float>(vec[i]);
    }
  }
};  // class SFrameIterBase
class SFrameImageIter : public SFrameIterBase {
public:
SFrameImageIter() :
augmenter_(new ImageAugmenter()), prnd_(new common::RANDOM_ENGINE(8964)) {}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
Parent::Init(kwargs);
augmenter_->Init(kwargs);
CHECK_EQ(Parent::param_.data_shape.ndim(), 3)
<< "Image shpae must be (channel, height, width)";
}
bool Next(void) override {
if (Parent::current_it_ == Parent::range_it_->end()) {
return false;
}
graphlab::image_type gl_img = (*Parent::current_it_)[0];
graphlab::flex_vec gl_label = (*Parent::current_it_)[1];
// TODO(bing): check not decoded
// TODO(bing): check img shape
CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match";
const unsigned char *raw_data = gl_img.get_image_data();
cv::Mat res;
cv::Mat buf(1, gl_img.m_image_data_size, CV_8U, const_cast<unsigned char*>(raw_data));
res = cv::imdecode(buf, -1);
res = augmenter_->Process(res, prnd_.get());
const int n_channels = res.channels();
if (!tmp_.Size()) {
tmp_.Push(Parent::idx_++,
Parent::param_.data_shape.get<3>(),
Parent::param_.label_shape.get<1>());
}
mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back();
std::vector<int> swap_indices;
if (n_channels == 1) swap_indices = {0};
if (n_channels == 3) swap_indices = {2, 1, 0};
for (int i = 0; i < res.rows; ++i) {
uchar* im_data = res.ptr<uchar>(i);
for (int j = 0; j < res.cols; ++j) {
for (int k = 0; k < n_channels; ++k) {
data[k][i][j] = im_data[swap_indices[k]];
}
im_data += n_channels;
}
}
mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back();
Parent::Copy_<1>(label, gl_label);
res.release();
out_ = Parent::tmp_[0];
++current_it_;
return true;
}
private:
/*! \brief parent type */
typedef SFrameIterBase Parent;
/*! \brief image augmenter */
std::unique_ptr<ImageAugmenter> augmenter_;
/*! \brief randim generator*/
std::unique_ptr<common::RANDOM_ENGINE> prnd_;
}; // class SFrameImageIter
/*!
 * \brief Iterator over an sframe whose data column holds plain numeric vectors.
 *
 * Each call to Next() copies the data and label vectors of the current row
 * into a 3-D data tensor and a 1-D label tensor of the configured shapes.
 */
class SFrameDataIter : public SFrameIterBase {
 public:
  /*!
   * \brief publish the current row as out_, then advance
   * \return false once the row cursor reached the end, true otherwise
   */
  bool Next() override {
    if (Parent::current_it_ == Parent::range_it_->end()) {
      return false;
    }
    graphlab::flex_vec gl_data = (*Parent::current_it_)[0];
    graphlab::flex_vec gl_label = (*Parent::current_it_)[1];
    CHECK_EQ(gl_data.size(), Parent::param_.data_shape.Size()) << "Data shape does not match";
    CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match";

    // The single slot of tmp_ is allocated once and overwritten on every call.
    if (!Parent::tmp_.Size()) {
      Parent::tmp_.Push(0,
                        Parent::param_.data_shape.get<3>(),
                        Parent::param_.label_shape.get<1>());
    }
    mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back();
    Parent::Copy_<3>(data, gl_data);
    mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back();
    Parent::Copy_<1>(label, gl_label);

    out_ = Parent::tmp_[0];
    // tmp_ keeps the index it was first pushed with, so stamp the running
    // instance index on every emitted instance explicitly.
    out_.index = Parent::idx_++;
    ++current_it_;
    return true;
  }

 private:
  /*! \brief parent type */
  typedef SFrameIterBase Parent;
};  // class SFrameDataIter
// Register SFrameParam with dmlc's parameter machinery; this is the
// source-file companion of the DMLC_DECLARE_PARAMETER block inside the struct.
DMLC_REGISTER_PARAMETER(SFrameParam);
// Register the image iterator under the name SFrameImageIter.
// The advertised arguments are the fields of every parameter struct listed
// below. The factory builds the chain, innermost to outermost:
// SFrameImageIter -> ImageNormalizeIter -> BatchLoader -> PrefetcherIter.
// The returned pointer is a raw `new`; ownership presumably passes to the
// caller of the factory -- TODO confirm against the registry.
MXNET_REGISTER_IO_ITER(SFrameImageIter)
.describe("Naive SFrame image iterator prototype")
.add_arguments(SFrameParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.add_arguments(ImageAugmentParam::__FIELDS__())
.add_arguments(ImageNormalizeParam::__FIELDS__())
.set_body([]() {
return new PrefetcherIter(
new BatchLoader(
new ImageNormalizeIter(
new SFrameImageIter())));
});
// Register the plain-data iterator under the name SFrameDataIter.
// Unlike the image variant there is no augmenter or normalizer layer: the
// factory builds SFrameDataIter -> BatchLoader -> PrefetcherIter (innermost
// to outermost) and advertises only the fields of the three structs below.
MXNET_REGISTER_IO_ITER(SFrameDataIter)
.describe("Naive SFrame data iterator prototype")
.add_arguments(SFrameParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.set_body([]() {
return new PrefetcherIter(
new BatchLoader(
new SFrameDataIter()));
});
} // namespace io
} // namespace mxnet