/*! | |
* Copyright (c) 2015 by Contributors | |
* \file crop-inl.h | |
* \brief | |
* \author Wei Wu | |
*/ | |
#ifndef MXNET_OPERATOR_CROP_INL_H_ | |
#define MXNET_OPERATOR_CROP_INL_H_ | |
#include <dmlc/logging.h> | |
#include <dmlc/parameter.h> | |
#include <mxnet/operator.h> | |
#include <cstring> | |
#include <map> | |
#include <string> | |
#include <vector> | |
#include <utility> | |
#include "./operator_common.h" | |
namespace mxnet { | |
namespace op { | |
namespace crop_enum { | |
enum CropOpInputs {kData, kCropLike}; | |
enum CropOpOutputs {kOut}; | |
} // namespace crop_enum | |
struct CropParam : public dmlc::Parameter<CropParam> { | |
int num_args; | |
TShape offset; | |
TShape h_w; | |
bool center_crop; | |
DMLC_DECLARE_PARAMETER(CropParam) { | |
DMLC_DECLARE_FIELD(num_args).set_range(1, 3) | |
.describe("Number of inputs for crop, if equals one, then we will use the h_w" | |
"for crop height and width, else if equals two, then we will use the height" | |
"and width of the second input symbol, we name crop_like here"); | |
int shape[] = {0, 0}; | |
DMLC_DECLARE_FIELD(offset).set_default(TShape(shape, shape + 2)) | |
.describe("crop offset coordinate: (y, x)"); | |
DMLC_DECLARE_FIELD(h_w).set_default(TShape(shape, shape + 2)) | |
.describe("crop height and width: (h, w)"); | |
DMLC_DECLARE_FIELD(center_crop).set_default(false) | |
.describe("If set to true, then it will use be the center_crop," | |
"or it will crop using the shape of crop_like"); | |
} | |
}; // struct CropParam | |
template<typename xpu> | |
class CropOp : public Operator { | |
public: | |
explicit CropOp(CropParam param) { | |
this->param_ = param; | |
} | |
virtual void Forward(const OpContext &ctx, | |
const std::vector<TBlob> &in_data, | |
const std::vector<OpReqType> &req, | |
const std::vector<TBlob> &out_data, | |
const std::vector<TBlob> &aux_args) { | |
using namespace mshadow; | |
using namespace mshadow::expr; | |
CHECK_EQ(static_cast<int>(in_data.size()), param_.num_args); | |
CHECK_EQ(out_data.size(), 1U); | |
CHECK_EQ(req[crop_enum::kOut], kWriteTo); | |
Stream<xpu> *s = ctx.get_stream<xpu>(); | |
Tensor<xpu, 4> data = in_data[crop_enum::kData].get<xpu, 4, real_t>(s); | |
Tensor<xpu, 4> out = out_data[crop_enum::kOut].get<xpu, 4, real_t>(s); | |
offset_hw_ = InferCropOfferset(data.shape_, out.shape_); | |
out = crop(data, Shape2(out.size(2), out.size(3)), offset_hw_[0], offset_hw_[1]); | |
} | |
// because the crop_like input is only used with it's shape, so we should be | |
// careful setting its backwrd grad value to zeros, so that it will not hurt | |
// the connection of crop_like. | |
virtual void Backward(const OpContext &ctx, | |
const std::vector<TBlob> &out_grad, | |
const std::vector<TBlob> &in_data, | |
const std::vector<TBlob> &out_data, | |
const std::vector<OpReqType> &req, | |
const std::vector<TBlob> &in_grad, | |
const std::vector<TBlob> &aux_states) { | |
using namespace mshadow; | |
using namespace mshadow::expr; | |
CHECK_EQ(in_grad.size(), static_cast<size_t>(param_.num_args)) << in_grad.size(); | |
CHECK_EQ(out_grad.size(), 1U) << out_grad.size(); | |
Stream<xpu> *s = ctx.get_stream<xpu>(); | |
Tensor<xpu, 4> grad = out_grad[crop_enum::kOut].get<xpu, 4, real_t>(s); | |
Tensor<xpu, 4> gdata = in_grad[crop_enum::kData].get<xpu, 4, real_t>(s); | |
if (param_.num_args > 1) { | |
// here backward grad is set to zero for crop_like | |
// however, this should only be done when num_args > 1, i.e., crop_like exists | |
Tensor<xpu, 4> gcrop_like = in_grad[crop_enum::kCropLike].get<xpu, 4, real_t>(s); | |
gcrop_like = (real_t)0.0f; | |
} | |
offset_hw_ = InferCropOfferset(gdata.shape_, grad.shape_); | |
gdata = (real_t)0.0f; | |
slice<3>(slice<2>(gdata, offset_hw_[0], offset_hw_[0]+grad.size(2)), | |
offset_hw_[1], offset_hw_[1]+grad.size(3)) = grad; | |
} | |
private: | |
CropParam param_; | |
std::vector<int> offset_hw_; | |
std::vector<int> InferCropOfferset(const mshadow::Shape<4> &data_shape, | |
const mshadow::Shape<4> &out_shape) { | |
std::vector<int> offset_hw; | |
CHECK_GE(data_shape[2], out_shape[2]) << | |
"data_shape'height should be larger than that of out_shape"; | |
CHECK_GE(data_shape[3], out_shape[3]) << | |
"data_shape'weight should be larger than that of out_shape"; | |
if (param_.center_crop) { | |
offset_hw.push_back(static_cast<int>((data_shape[2]-out_shape[2])/2)); | |
offset_hw.push_back(static_cast<int>((data_shape[3]-out_shape[3])/2)); | |
} else { | |
CHECK_GE(static_cast<int>(param_.offset[0]), 0) << | |
"offset[0] should be larger than 0"; | |
CHECK_LE(param_.offset[0], data_shape[2]-out_shape[2]) << | |
"offset[0] should be less than the residual space of height"; | |
CHECK_GE(static_cast<int>(param_.offset[1]), 0) << | |
"offset[1] should be larger than 0"; | |
CHECK_LE(param_.offset[1], data_shape[3]-out_shape[3]) << | |
"offset[1] should be less than the residual space of width"; | |
offset_hw.push_back(static_cast<int>(param_.offset[0])); | |
offset_hw.push_back(static_cast<int>(param_.offset[1])); | |
} | |
return offset_hw; | |
} | |
}; // class CropOp | |
template<typename xpu> | |
Operator *CreateOp(CropParam param); | |
#if DMLC_USE_CXX11 | |
class CropProp : public OperatorProperty { | |
public: | |
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { | |
param_.Init(kwargs); | |
} | |
std::map<std::string, std::string> GetParams() const override { | |
return param_.__DICT__(); | |
} | |
std::vector<std::string> ListArguments() const override { | |
// return {"data", "crop_like"}; | |
std::vector<std::string> ret; | |
for (int i = 0; i < param_.num_args; ++i) { | |
ret.push_back(std::string("arg") + std::to_string(i)); | |
} | |
return ret; | |
} | |
bool InferShape(std::vector<TShape> *in_shape, | |
std::vector<TShape> *out_shape, | |
std::vector<TShape> *aux_shape) const override { | |
using namespace mshadow; | |
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args)); | |
TShape data_shape = in_shape->at(crop_enum::kData); | |
if (data_shape.ndim() == 0) return false; | |
CHECK_EQ(data_shape.ndim(), 4U) << \ | |
"Input data should be 4D in batch-num_filter-y-x"; | |
std::vector<int> crop_shape; | |
if (param_.num_args == 1) { | |
CHECK_GE(static_cast<int>(param_.h_w[0]), 1) << | |
"the crop height(h_w[0]) should be larger than 1"; | |
CHECK_LE(static_cast<int>(param_.h_w[0]), static_cast<int>(data_shape[2])) << | |
"the crop height(h_w[0]) should be less than the input data's height"; | |
CHECK_GE(static_cast<int>(param_.h_w[1]), 1) << | |
"the crop width(h_w[1]) should be larger than 1"; | |
CHECK_LE(static_cast<int>(param_.h_w[1]), static_cast<int>(data_shape[3])) << | |
"the crop width(h_w[1]) should be less than the input data's width"; | |
crop_shape.push_back(param_.h_w[0]); | |
crop_shape.push_back(param_.h_w[1]); | |
} else if (param_.num_args == 2) { | |
TShape crop_like_shape = in_shape->at(crop_enum::kCropLike); | |
crop_shape.push_back(crop_like_shape[2]); | |
crop_shape.push_back(crop_like_shape[3]); | |
} | |
if (crop_shape.size() == 0) return false; | |
CHECK_EQ(crop_shape.size(), 2U) << \ | |
"Input crop_like should be 2D in height-width"; | |
out_shape->clear(); | |
data_shape[2] = crop_shape[0]; | |
data_shape[3] = crop_shape[1]; | |
out_shape->push_back(data_shape); | |
return true; | |
} | |
OperatorProperty* Copy() const override { | |
auto ptr = new CropProp(); | |
ptr->param_ = param_; | |
return ptr; | |
} | |
std::string TypeString() const override { | |
return "Crop"; | |
} | |
std::vector<int> DeclareBackwardDependency( | |
const std::vector<int> &out_grad, | |
const std::vector<int> &in_data, | |
const std::vector<int> &out_data) const override { | |
return out_grad; | |
} | |
Operator* CreateOperator(Context ctx) const override; | |
private: | |
CropParam param_; | |
}; // class CropProp | |
#endif // DMLC_USE_CXX11 | |
} // namespace op | |
} // namespace mxnet | |
#endif // MXNET_OPERATOR_CROP_INL_H_ |