| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Copyright (c) 2015 by Contributors |
| * \file crop-inl.h |
| * \brief Crop operator: crops the height and width (axes 2 and 3) of a 4D input. |
| * \author Wei Wu |
| */ |
| #ifndef MXNET_OPERATOR_CROP_INL_H_ |
| #define MXNET_OPERATOR_CROP_INL_H_ |
| #include <dmlc/logging.h> |
| #include <dmlc/parameter.h> |
| #include <mxnet/operator.h> |
| #include <cstring> |
| #include <map> |
| #include <string> |
| #include <vector> |
| #include <utility> |
| #include "./operator_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
namespace crop_enum {
/*! \brief Positions of the Crop operator's inputs. */
enum CropOpInputs {
  kData = 0,     // tensor that is cropped
  kCropLike = 1  // second input (num_args == 2); only its height/width are used
};
/*! \brief Positions of the Crop operator's outputs. */
enum CropOpOutputs {
  kOut = 0  // cropped result
};
}  // namespace crop_enum
| |
| struct CropParam : public dmlc::Parameter<CropParam> { |
| int num_args; |
| mxnet::TShape offset; |
| mxnet::TShape h_w; |
| bool center_crop; |
| DMLC_DECLARE_PARAMETER(CropParam) { |
| DMLC_DECLARE_FIELD(num_args).set_range(1, 3) |
| .describe("Number of inputs for crop, if equals one, then we will use the h_w" |
| "for crop height and width, else if equals two, then we will use the height" |
| "and width of the second input symbol, we name crop_like here"); |
| int shape[] = {0, 0}; |
| DMLC_DECLARE_FIELD(offset).set_default(mxnet::TShape(shape, shape + 2)) |
| .describe("crop offset coordinate: (y, x)"); |
| DMLC_DECLARE_FIELD(h_w).set_default(mxnet::TShape(shape, shape + 2)) |
| .describe("crop height and width: (h, w)"); |
| DMLC_DECLARE_FIELD(center_crop).set_default(false) |
| .describe("If set to true, then it will use be the center_crop," |
| "or it will crop using the shape of crop_like"); |
| } |
| }; // struct CropParam |
| |
| template<typename xpu> |
| class CropOp : public Operator { |
| public: |
| explicit CropOp(CropParam param) { |
| this->param_ = param; |
| } |
| |
| virtual void Forward(const OpContext &ctx, |
| const std::vector<TBlob> &in_data, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &out_data, |
| const std::vector<TBlob> &aux_args) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| CHECK_EQ(static_cast<int>(in_data.size()), param_.num_args); |
| CHECK_EQ(out_data.size(), 1U); |
| CHECK_EQ(req[crop_enum::kOut], kWriteTo); |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| Tensor<xpu, 4> data = in_data[crop_enum::kData].get<xpu, 4, real_t>(s); |
| Tensor<xpu, 4> out = out_data[crop_enum::kOut].get<xpu, 4, real_t>(s); |
| offset_hw_ = InferCropOfferset(data.shape_, out.shape_); |
| out = crop(data, Shape2(out.size(2), out.size(3)), offset_hw_[0], offset_hw_[1]); |
| } |
| |
| // because the crop_like input is only used with it's shape, so we should be |
| // careful setting its backwrd grad value to zeros, so that it will not hurt |
| // the connection of crop_like. |
| virtual void Backward(const OpContext &ctx, |
| const std::vector<TBlob> &out_grad, |
| const std::vector<TBlob> &in_data, |
| const std::vector<TBlob> &out_data, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &in_grad, |
| const std::vector<TBlob> &aux_states) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| CHECK_EQ(in_grad.size(), static_cast<size_t>(param_.num_args)) << in_grad.size(); |
| CHECK_EQ(out_grad.size(), 1U) << out_grad.size(); |
| Stream<xpu> *s = ctx.get_stream<xpu>(); |
| Tensor<xpu, 4> grad = out_grad[crop_enum::kOut].get<xpu, 4, real_t>(s); |
| Tensor<xpu, 4> gdata = in_grad[crop_enum::kData].get<xpu, 4, real_t>(s); |
| if (param_.num_args > 1) { |
| // here backward grad is set to zero for crop_like |
| // however, this should only be done when num_args > 1, i.e., crop_like exists |
| Tensor<xpu, 4> gcrop_like = in_grad[crop_enum::kCropLike].get<xpu, 4, real_t>(s); |
| gcrop_like = (real_t)0.0f; |
| } |
| offset_hw_ = InferCropOfferset(gdata.shape_, grad.shape_); |
| gdata = (real_t)0.0f; |
| slice<3>(slice<2>(gdata, offset_hw_[0], offset_hw_[0]+grad.size(2)), |
| offset_hw_[1], offset_hw_[1]+grad.size(3)) = grad; |
| } |
| |
| private: |
| CropParam param_; |
| std::vector<int> offset_hw_; |
| std::vector<int> InferCropOfferset(const mshadow::Shape<4> &data_shape, |
| const mshadow::Shape<4> &out_shape) { |
| std::vector<int> offset_hw; |
| CHECK_GE(data_shape[2], out_shape[2]) << |
| "data_shape'height should be larger than that of out_shape"; |
| CHECK_GE(data_shape[3], out_shape[3]) << |
| "data_shape'weight should be larger than that of out_shape"; |
| if (param_.center_crop) { |
| offset_hw.push_back(static_cast<int>((data_shape[2]-out_shape[2])/2)); |
| offset_hw.push_back(static_cast<int>((data_shape[3]-out_shape[3])/2)); |
| } else { |
| CHECK_GE(static_cast<int>(param_.offset[0]), 0) << |
| "offset[0] should be larger than 0"; |
| CHECK_LE(param_.offset[0], data_shape[2]-out_shape[2]) << |
| "offset[0] should be less than the residual space of height"; |
| CHECK_GE(static_cast<int>(param_.offset[1]), 0) << |
| "offset[1] should be larger than 0"; |
| CHECK_LE(param_.offset[1], data_shape[3]-out_shape[3]) << |
| "offset[1] should be less than the residual space of width"; |
| offset_hw.push_back(static_cast<int>(param_.offset[0])); |
| offset_hw.push_back(static_cast<int>(param_.offset[1])); |
| } |
| return offset_hw; |
| } |
| }; // class CropOp |
| |
| template<typename xpu> |
| Operator *CreateOp(CropParam param); |
| |
#if DMLC_USE_CXX11
/*!
 * \brief Operator property of Crop: parameter handling, argument naming and
 *        shape inference.
 */
class CropProp : public OperatorProperty {
 public:
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
    param_.Init(kwargs);
  }

  std::map<std::string, std::string> GetParams() const override {
    return param_.__DICT__();
  }

  /*!
   * \brief Arguments are named "arg0", "arg1", ...: arg0 is the data to crop
   *        and, when num_args == 2, arg1 is crop_like.
   */
  std::vector<std::string> ListArguments() const override {
    std::vector<std::string> ret;
    ret.reserve(static_cast<size_t>(param_.num_args));
    for (int i = 0; i < param_.num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  }

  /*!
   * \brief The output equals the data shape with axes 2 and 3 replaced by the
   *        crop size (from h_w, or from crop_like's height/width).
   * \return false while a required input shape is still unknown
   */
  bool InferShape(mxnet::ShapeVector *in_shape,
                  mxnet::ShapeVector *out_shape,
                  mxnet::ShapeVector *aux_shape) const override {
    CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
    mxnet::TShape data_shape = in_shape->at(crop_enum::kData);
    if (data_shape.ndim() == 0) return false;
    CHECK_EQ(data_shape.ndim(), 4U) <<
      "Input data should be 4D in batch-num_filter-y-x";

    int crop_h = 0;
    int crop_w = 0;
    if (param_.num_args == 1) {
      // h_w[1] is read below, so both entries must be present.
      CHECK_EQ(static_cast<int>(param_.h_w.ndim()), 2) <<
        "h_w should have two elements: (h, w)";
      crop_h = static_cast<int>(param_.h_w[0]);
      crop_w = static_cast<int>(param_.h_w[1]);
      CHECK_GE(crop_h, 1) <<
        "the crop height(h_w[0]) should be at least 1";
      CHECK_LE(crop_h, static_cast<int>(data_shape[2])) <<
        "the crop height(h_w[0]) should not exceed the input data's height";
      CHECK_GE(crop_w, 1) <<
        "the crop width(h_w[1]) should be at least 1";
      CHECK_LE(crop_w, static_cast<int>(data_shape[3])) <<
        "the crop width(h_w[1]) should not exceed the input data's width";
    } else if (param_.num_args == 2) {
      const mxnet::TShape &crop_like_shape = in_shape->at(crop_enum::kCropLike);
      // crop_like's shape may not be inferred yet; indexing it would be invalid.
      if (crop_like_shape.ndim() == 0) return false;
      // Backward reads crop_like's gradient as a 4D tensor.
      CHECK_EQ(crop_like_shape.ndim(), 4U) <<
        "Input crop_like should be 4D in batch-num_filter-y-x";
      crop_h = static_cast<int>(crop_like_shape[2]);
      crop_w = static_cast<int>(crop_like_shape[3]);
    } else {
      // Unsupported number of inputs: nothing can be inferred.
      return false;
    }

    out_shape->clear();
    data_shape[2] = crop_h;
    data_shape[3] = crop_w;
    out_shape->push_back(data_shape);
    return true;
  }

  OperatorProperty* Copy() const override {
    auto ptr = new CropProp();
    ptr->param_ = param_;
    return ptr;
  }

  std::string TypeString() const override {
    return "Crop";
  }

  /*!
   * \brief Backward reads only out_grad (input-gradient shapes come from the
   *        gradient blobs themselves), so in_data/out_data are not required.
   */
  std::vector<int> DeclareBackwardDependency(
    const std::vector<int> &out_grad,
    const std::vector<int> &in_data,
    const std::vector<int> &out_data) const override {
    return out_grad;
  }

  Operator* CreateOperator(Context ctx) const override;

 private:
  CropParam param_;
};  // class CropProp
#endif  // DMLC_USE_CXX11
| } // namespace op |
| } // namespace mxnet |
| #endif // MXNET_OPERATOR_CROP_INL_H_ |