blob: 324ec3c1e452256d328a54eefa89b1004823cc9f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * Copyright (c) 2017 by Contributors
 * \file im2rec.cc
 * \brief Implementation of the im2rec utility exported to R: packs the images
 *  listed in an image list file into a RecordIO file.
 */
#include <cctype>
#include <cstring>
#include <iomanip>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "dmlc/base.h"
#include "dmlc/io.h"
#include "dmlc/timer.h"
#include "dmlc/logging.h"
#include "dmlc/recordio.h"
#include <opencv2/opencv.hpp>
#include "image_recordio.h"
#include "base.h"
#include "im2rec.h"
namespace mxnet {
namespace R {
/*!
 * \brief Resolve the interpolation code handed to cv::resize.
 *
 * Mode 9 ("auto") picks cubic when both dimensions grow, area when both
 * shrink, and linear otherwise. Mode 10 ("rand") draws uniformly from 0..4
 * using prnd. Any other value is returned unchanged.
 */
int GetInterMethod(int inter_method, int old_width, int old_height,
                   int new_width, int new_height, std::mt19937& prnd) {  // NOLINT(*)
  // Explicit interpolation codes are forwarded as-is.
  if (inter_method != 9 && inter_method != 10) {
    return inter_method;
  }
  if (inter_method == 10) {
    // Uniform choice among nn(0)/bilinear(1)/cubic(2)/area(3)/lanczos4(4).
    std::uniform_int_distribution<size_t> pick(0, 4);
    return pick(prnd);
  }
  // Mode 9: decide from the direction of the resize on both axes.
  const bool enlarging = new_width > old_width && new_height > old_height;
  const bool shrinking = new_width < old_width && new_height < old_height;
  if (enlarging) return 2;  // CV_INTER_CUBIC for enlarge
  if (shrinking) return 3;  // CV_INTER_AREA for shrink
  return 1;                 // CV_INTER_LINEAR for others
}
IM2REC* IM2REC::Get() {
  // Function-local static: constructed on the first call, and every caller
  // receives a pointer to this same instance.
  static IM2REC singleton;
  IM2REC* const instance = &singleton;
  return instance;
}
void IM2REC::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
IM2REC::Get()->scope_ = ::getCurrentScope();
function("mx.internal.im2rec", &IM2REC::im2rec,
Rcpp::List::create(_["image_lst"],
_["root"],
_["output_rec"],
_["label_width"],
_["pack_label"],
_["new_size"],
_["nsplit"],
_["partid"],
_["center_crop"],
_["quality"],
_["color_mode"],
_["unchanged"],
_["inter_method"],
_["encoding"]),
"");
}
void IM2REC::im2rec(const std::string & image_lst, const std::string & root,
const std::string & output_rec,
int label_width, int pack_label, int new_size, int nsplit,
int partid, int center_crop, int quality,
int color_mode, int unchanged,
int inter_method, std::string encoding) {
// Check parameters ranges
if (color_mode != -1 && color_mode != 0 && color_mode != 1) {
Rcpp::stop("Color mode must be -1, 0 or 1.");
}
if (encoding != std::string(".jpg") && encoding != std::string(".png")) {
Rcpp::stop("Encoding mode must be .jpg or .png.");
}
if (label_width <= 1 && pack_label) {
Rcpp::stop("pack_label can only be used when label_width > 1");
}
if (new_size > 0) {
LOG(INFO) << "New Image Size: Short Edge " << new_size;
} else {
LOG(INFO) << "Keep origin image size";
}
if (center_crop) {
LOG(INFO) << "Center cropping to square";
}
if (color_mode == 0) {
LOG(INFO) << "Use gray images";
}
if (color_mode == -1) {
LOG(INFO) << "Keep original color mode";
}
LOG(INFO) << "Encoding is " << encoding;
if (encoding == std::string(".png") && quality > 9) {
quality = 3;
}
if (inter_method != 1) {
switch (inter_method) {
case 0:
LOG(INFO) << "Use inter_method CV_INTER_NN";
break;
case 2:
LOG(INFO) << "Use inter_method CV_INTER_CUBIC";
break;
case 3:
LOG(INFO) << "Use inter_method CV_INTER_AREA";
break;
case 4:
LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4";
break;
case 9:
LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)";
break;
case 10:
LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)";
break;
}
}
std::random_device rd;
std::mt19937 prnd(rd());
using namespace dmlc;
static const size_t kBufferSize = 1 << 20UL;
mxnet::io::ImageRecordIO rec;
size_t imcnt = 0;
double tstart = dmlc::GetTime();
dmlc::InputSplit *flist =
dmlc::InputSplit::Create(image_lst.c_str(), partid, nsplit, "text");
std::ostringstream os;
if (nsplit == 1) {
os << output_rec;
} else {
os << output_rec << ".part" << std::setw(3) << std::setfill('0') << partid;
}
LOG(INFO) << "Write to output: " << os.str();
dmlc::Stream *fo = dmlc::Stream::Create(os.str().c_str(), "w");
LOG(INFO) << "Output: " << os.str();
dmlc::RecordIOWriter writer(fo);
std::string fname, path, blob;
std::vector<unsigned char> decode_buf;
std::vector<unsigned char> encode_buf;
std::vector<int> encode_params;
if (encoding == std::string(".png")) {
encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION);
encode_params.push_back(quality);
LOG(INFO) << "PNG encoding compression: " << quality;
} else {
encode_params.push_back(CV_IMWRITE_JPEG_QUALITY);
encode_params.push_back(quality);
LOG(INFO) << "JPEG encoding quality: " << quality;
}
dmlc::InputSplit::Blob line;
std::vector<float> label_buf(label_width, 0.f);
while (flist->NextRecord(&line)) {
std::string sline(static_cast<char*>(line.dptr), line.size);
std::istringstream is(sline);
if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue;
label_buf[0] = rec.header.label;
for (int k = 1; k < label_width; ++k) {
RCHECK(is >> label_buf[k])
<< "Invalid ImageList, did you provide the correct label_width?";
}
if (pack_label) rec.header.flag = label_width;
rec.SaveHeader(&blob);
if (pack_label) {
size_t bsize = blob.size();
blob.resize(bsize + label_buf.size()*sizeof(float));
memcpy(BeginPtr(blob) + bsize,
BeginPtr(label_buf), label_buf.size()*sizeof(float));
}
RCHECK(std::getline(is, fname));
// eliminate invalid chars in the end
while (fname.length() != 0 &&
(isspace(*fname.rbegin()) || !isprint(*fname.rbegin()))) {
fname.resize(fname.length() - 1);
}
// eliminate invalid chars in beginning.
const char *p = fname.c_str();
while (isspace(*p)) ++p;
path = root + p;
// use "r" is equal to rb in dmlc::Stream
dmlc::Stream *fi = dmlc::Stream::Create(path.c_str(), "r");
decode_buf.clear();
size_t imsize = 0;
while (true) {
decode_buf.resize(imsize + kBufferSize);
size_t nread = fi->Read(BeginPtr(decode_buf) + imsize, kBufferSize);
imsize += nread;
decode_buf.resize(imsize);
if (nread != kBufferSize) break;
}
delete fi;
if (unchanged != 1) {
cv::Mat img = cv::imdecode(decode_buf, color_mode);
RCHECK(img.data != NULL) << "OpenCV decode fail:" << path;
cv::Mat res = img;
if (new_size > 0) {
if (center_crop) {
if (img.rows > img.cols) {
int margin = (img.rows - img.cols)/2;
img = img(cv::Range(margin, margin+img.cols), cv::Range(0, img.cols));
} else {
int margin = (img.cols - img.rows)/2;
img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows));
}
}
int interpolation_method = 1;
if (img.rows > img.cols) {
if (img.cols != new_size) {
interpolation_method = GetInterMethod(inter_method, img.cols, img.rows,
new_size,
img.rows * new_size / img.cols, prnd);
cv::resize(img, res, cv::Size(new_size,
img.rows * new_size / img.cols),
0, 0, interpolation_method);
} else {
res = img.clone();
}
} else {
if (img.rows != new_size) {
interpolation_method = GetInterMethod(inter_method, img.cols,
img.rows, new_size * img.cols / img.rows,
new_size, prnd);
cv::resize(img, res, cv::Size(new_size * img.cols / img.rows,
new_size), 0, 0, interpolation_method);
} else {
res = img.clone();
}
}
}
encode_buf.clear();
RCHECK(cv::imencode(encoding, res, encode_buf, encode_params));
// write buffer
size_t bsize = blob.size();
blob.resize(bsize + encode_buf.size());
memcpy(BeginPtr(blob) + bsize,
BeginPtr(encode_buf), encode_buf.size());
} else {
size_t bsize = blob.size();
blob.resize(bsize + decode_buf.size());
memcpy(BeginPtr(blob) + bsize,
BeginPtr(decode_buf), decode_buf.size());
}
writer.WriteRecord(BeginPtr(blob), blob.size());
// write header
++imcnt;
if (imcnt % 1000 == 0) {
LOG(INFO) << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
}
}
LOG(INFO) << "Total: " << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
delete fo;
delete flist;
}
} // namespace R
} // namespace mxnet