// blob: 5b5adbf15874295c4340d4e2b825933cfc0c4cf4 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file crop-inl.h
* \brief Crop operator: crops the height and width (dims 2 and 3) of a 4D (batch, num_filter, y, x) input
* \author Wei Wu
*/
#ifndef MXNET_OPERATOR_CROP_INL_H_
#define MXNET_OPERATOR_CROP_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "./operator_common.h"
namespace mxnet {
namespace op {
namespace crop_enum {
// Positions of the operator's inputs in in_data / in_grad.
enum CropOpInputs {
  kData = 0,     // tensor that gets cropped
  kCropLike = 1  // present only when num_args == 2; its (h, w) gives the crop size
};
// Positions of the operator's outputs in out_data / out_grad.
enum CropOpOutputs {
  kOut = 0       // the cropped tensor
};
}  // namespace crop_enum
/*!
 * \brief User-facing parameters of the Crop operator.
 */
struct CropParam : public dmlc::Parameter<CropParam> {
  /*! \brief number of inputs: 1 (data only) or 2 (data and crop_like) */
  int num_args;
  /*! \brief (y, x) start of the crop window; not read when center_crop is true */
  TShape offset;
  /*! \brief (h, w) of the crop window; only read when num_args == 1 */
  TShape h_w;
  /*! \brief when true the crop window is centered and offset is ignored */
  bool center_crop;
  DMLC_DECLARE_PARAMETER(CropParam) {
    // dmlc's set_range is inclusive; only 1 and 2 inputs are supported by
    // shape inference and backward, so 3 must be rejected up front.
    DMLC_DECLARE_FIELD(num_args).set_range(1, 2)
    .describe("Number of inputs for crop, if equals one, then we will use the h_w "
              "for crop height and width, else if equals two, then we will use the height "
              "and width of the second input symbol, we name crop_like here");
    int shape[] = {0, 0};
    DMLC_DECLARE_FIELD(offset).set_default(TShape(shape, shape + 2))
    .describe("crop offset coordinate: (y, x)");
    DMLC_DECLARE_FIELD(h_w).set_default(TShape(shape, shape + 2))
    .describe("crop height and width: (h, w)");
    DMLC_DECLARE_FIELD(center_crop).set_default(false)
    .describe("If set to true, then it will use the center_crop, "
              "or it will crop using the offset");
  }
};  // struct CropParam
/*!
 * \brief Crops the spatial (height, width) dimensions of a 4D input.
 *
 * The output height/width are fixed by shape inference; this operator only
 * decides where the crop window starts: at param_.offset, or centered when
 * param_.center_crop is true.
 * \tparam xpu device type the tensors live on.
 */
template<typename xpu>
class CropOp : public Operator {
 public:
  explicit CropOp(CropParam param) : param_(param) {}

  /*!
   * \brief Copies the crop window of in_data[kData] into out_data[kOut].
   * Only kWriteTo is supported for the output.
   */
  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(static_cast<int>(in_data.size()), param_.num_args);
    CHECK_EQ(out_data.size(), 1U);
    CHECK_EQ(req[crop_enum::kOut], kWriteTo);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4> data = in_data[crop_enum::kData].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> out = out_data[crop_enum::kOut].get<xpu, 4, real_t>(s);
    const std::vector<int> offset_hw = InferCropOffset(data.shape_, out.shape_);
    out = crop(data, Shape2(out.size(2), out.size(3)), offset_hw[0], offset_hw[1]);
  }

  /*!
   * \brief Scatters out_grad back into the crop window of in_grad[kData].
   *
   * Data-gradient positions outside the crop window get zero. crop_like (when
   * num_args > 1) is only used for its shape, so its gradient is zero; it is
   * still written so the crop_like branch receives a defined gradient.
   * Honors req per input: kNullOp skips it, kAddTo accumulates, and any other
   * request overwrites.
   */
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_grad.size(), static_cast<size_t>(param_.num_args)) << in_grad.size();
    CHECK_EQ(out_grad.size(), 1U) << out_grad.size();
    CHECK_EQ(req.size(), in_grad.size()) << req.size();
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4> grad = out_grad[crop_enum::kOut].get<xpu, 4, real_t>(s);
    if (param_.num_args > 1) {
      const OpReqType crop_like_req = req[crop_enum::kCropLike];
      // kNullOp: no gradient buffer is requested; kAddTo: adding zero is a no-op.
      if (crop_like_req != kNullOp && crop_like_req != kAddTo) {
        Tensor<xpu, 4> gcrop_like = in_grad[crop_enum::kCropLike].get<xpu, 4, real_t>(s);
        gcrop_like = static_cast<real_t>(0.0f);
      }
    }
    const OpReqType data_req = req[crop_enum::kData];
    if (data_req == kNullOp) return;
    Tensor<xpu, 4> gdata = in_grad[crop_enum::kData].get<xpu, 4, real_t>(s);
    const std::vector<int> offset_hw = InferCropOffset(gdata.shape_, grad.shape_);
    if (data_req == kAddTo) {
      // Accumulate into the existing gradient; outside the window nothing is added.
      slice<3>(slice<2>(gdata, offset_hw[0], offset_hw[0] + grad.size(2)),
               offset_hw[1], offset_hw[1] + grad.size(3)) += grad;
    } else {
      gdata = static_cast<real_t>(0.0f);
      slice<3>(slice<2>(gdata, offset_hw[0], offset_hw[0] + grad.size(2)),
               offset_hw[1], offset_hw[1] + grad.size(3)) = grad;
    }
  }

 private:
  /*!
   * \brief Returns the (y, x) start of the crop window inside data_shape.
   * Height is dimension 2 and width is dimension 3 of both shapes.
   */
  std::vector<int> InferCropOffset(const mshadow::Shape<4> &data_shape,
                                   const mshadow::Shape<4> &out_shape) const {
    CHECK_GE(data_shape[2], out_shape[2]) <<
        "data_shape's height should be larger than or equal to that of out_shape";
    CHECK_GE(data_shape[3], out_shape[3]) <<
        "data_shape's width should be larger than or equal to that of out_shape";
    std::vector<int> offset_hw;
    if (param_.center_crop) {
      // Integer division: an odd leftover puts the extra row/column after the window.
      offset_hw.push_back(static_cast<int>((data_shape[2] - out_shape[2]) / 2));
      offset_hw.push_back(static_cast<int>((data_shape[3] - out_shape[3]) / 2));
    } else {
      CHECK_EQ(param_.offset.ndim(), 2U) <<
          "offset should have exactly two elements: (y, x)";
      CHECK_GE(static_cast<int>(param_.offset[0]), 0) <<
          "offset[0] should be non-negative";
      CHECK_LE(param_.offset[0], data_shape[2] - out_shape[2]) <<
          "offset[0] should not exceed the residual space of height";
      CHECK_GE(static_cast<int>(param_.offset[1]), 0) <<
          "offset[1] should be non-negative";
      CHECK_LE(param_.offset[1], data_shape[3] - out_shape[3]) <<
          "offset[1] should not exceed the residual space of width";
      offset_hw.push_back(static_cast<int>(param_.offset[0]));
      offset_hw.push_back(static_cast<int>(param_.offset[1]));
    }
    return offset_hw;
  }

  CropParam param_;
};  // class CropOp
/*!
 * \brief Device-specialised factory for CropOp.
 * \tparam xpu device type; the specialisations are defined outside this
 *             header (presumably one per device — TODO confirm).
 * \param param parsed Crop parameters.
 * \return the operator instance (ownership follows the caller's convention,
 *         not visible here).
 */
template <typename xpu>
Operator* CreateOp(CropParam param);
#if DMLC_USE_CXX11
class CropProp : public OperatorProperty {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}
std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}
std::vector<std::string> ListArguments() const override {
// return {"data", "crop_like"};
std::vector<std::string> ret;
for (int i = 0; i < param_.num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
}
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
TShape data_shape = in_shape->at(crop_enum::kData);
if (data_shape.ndim() == 0) return false;
CHECK_EQ(data_shape.ndim(), 4U) << \
"Input data should be 4D in batch-num_filter-y-x";
std::vector<int> crop_shape;
if (param_.num_args == 1) {
CHECK_GE(static_cast<int>(param_.h_w[0]), 1) <<
"the crop height(h_w[0]) should be larger than 1";
CHECK_LE(static_cast<int>(param_.h_w[0]), static_cast<int>(data_shape[2])) <<
"the crop height(h_w[0]) should be less than the input data's height";
CHECK_GE(static_cast<int>(param_.h_w[1]), 1) <<
"the crop width(h_w[1]) should be larger than 1";
CHECK_LE(static_cast<int>(param_.h_w[1]), static_cast<int>(data_shape[3])) <<
"the crop width(h_w[1]) should be less than the input data's width";
crop_shape.push_back(param_.h_w[0]);
crop_shape.push_back(param_.h_w[1]);
} else if (param_.num_args == 2) {
TShape crop_like_shape = in_shape->at(crop_enum::kCropLike);
crop_shape.push_back(crop_like_shape[2]);
crop_shape.push_back(crop_like_shape[3]);
}
if (crop_shape.size() == 0) return false;
CHECK_EQ(crop_shape.size(), 2U) << \
"Input crop_like should be 2D in height-width";
out_shape->clear();
data_shape[2] = crop_shape[0];
data_shape[3] = crop_shape[1];
out_shape->push_back(data_shape);
return true;
}
OperatorProperty* Copy() const override {
auto ptr = new CropProp();
ptr->param_ = param_;
return ptr;
}
std::string TypeString() const override {
return "Crop";
}
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
return out_grad;
}
Operator* CreateOperator(Context ctx) const override;
private:
CropParam param_;
}; // class CropProp
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CROP_INL_H_