/*! | |
* Copyright (c) 2017 by Contributors | |
* \file bilinear_sampler-inl.h
* \brief Declaration of the BilinearSampler operator and its operator property.
* \author Xu Dong | |
*/ | |
#ifndef MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ | |
#define MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ | |
#include <dmlc/logging.h> | |
#include <dmlc/parameter.h> | |
#include <mxnet/operator.h> | |
#include <vector> | |
#include <map> | |
#include <string> | |
#include <utility> | |
#include "./operator_common.h" | |
namespace mxnet { | |
namespace op { | |
namespace bs {
/*! \brief Positions of the operator inputs (used to index in_data / in_grad / req in Backward). */
enum BilinearSamplerOpInputs {
  kData = 0,
  kGrid = 1
};
/*! \brief Positions of the operator outputs (used to index out_data / out_grad). */
enum BilinearSamplerOpOutputs {
  kOut = 0,
  kTmp = 1
};
}  // namespace bs
/*!
 * \brief Parameter set of the BilinearSampler operator.
 *
 * The declaration body is empty, i.e. no fields are registered here.
 */
struct BilinearSamplerParam : dmlc::Parameter<BilinearSamplerParam> {
  DMLC_DECLARE_PARAMETER(BilinearSamplerParam) {}
};
template<typename xpu, typename DType> | |
class BilinearSamplerOp : public Operator { | |
public: | |
explicit BilinearSamplerOp(BilinearSamplerParam p) { | |
this->param_ = p; | |
} | |
virtual void Forward(const OpContext &ctx, | |
const std::vector<TBlob> &in_data, | |
const std::vector<OpReqType> &req, | |
const std::vector<TBlob> &out_data, | |
const std::vector<TBlob> &aux_args) { | |
using namespace mshadow; | |
using namespace mshadow::expr; | |
CHECK_EQ(req[bs::kOut], kWriteTo); | |
CHECK_EQ(in_data.size(), 2U); | |
Stream<xpu> *s = ctx.get_stream<xpu>(); | |
Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s); | |
Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s); | |
Tensor<xpu, 4, DType> out = out_data[bs::kOut].get<xpu, 4, DType>(s); | |
BilinearSamplerForward(out, data, grid); | |
} | |
virtual void Backward(const OpContext &ctx, | |
const std::vector<TBlob> &out_grad, | |
const std::vector<TBlob> &in_data, | |
const std::vector<TBlob> &out_data, | |
const std::vector<OpReqType> &req, | |
const std::vector<TBlob> &in_grad, | |
const std::vector<TBlob> &aux_args) { | |
using namespace mshadow; | |
using namespace mshadow::expr; | |
CHECK_EQ(in_data.size(), 2U); | |
CHECK_NE(req[bs::kData], kWriteInplace); | |
CHECK_NE(req[bs::kGrid], kWriteInplace); | |
Stream<xpu> *s = ctx.get_stream<xpu>(); | |
Tensor<xpu, 4, DType> data = in_data[bs::kData].get<xpu, 4, DType>(s); | |
Tensor<xpu, 4, DType> grid = in_data[bs::kGrid].get<xpu, 4, DType>(s); | |
Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s); | |
Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s); | |
Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s); | |
if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) { | |
if (req[bs::kData] == kWriteTo) { | |
gdata = scalar<DType>(0.0f); | |
} | |
if (req[bs::kGrid] == kWriteTo) { | |
ggrid = scalar<DType>(0.0f); | |
} | |
BilinearSamplerBackward(gdata, ggrid, grad, data, grid); | |
} else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { | |
return; | |
} else { | |
LOG(FATAL) << "Have not implemented the data req combinations! gdata_req=" | |
<< req[bs::kData] << " ggrid_req=" << req[bs::kGrid]; | |
} | |
} | |
private: | |
BilinearSamplerParam param_; | |
}; // class BilinearSamplerOp | |
template<typename xpu> | |
Operator* CreateOp(BilinearSamplerParam param, int dtype); | |
#if DMLC_USE_CXX11 | |
class BilinearSamplerProp : public OperatorProperty { | |
public: | |
int NumVisibleOutputs() const override { | |
return 1; | |
} | |
int NumOutputs() const override { | |
return 2; | |
} | |
std::vector<std::string> ListArguments() const override { | |
return {"data", "grid"}; | |
} | |
std::vector<std::string> ListOutputs() const override { | |
return {"output", "tmp"}; | |
} | |
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { | |
param_.Init(kwargs); | |
} | |
std::map<std::string, std::string> GetParams() const override { | |
return param_.__DICT__(); | |
} | |
bool InferShape(std::vector<TShape> *in_shape, | |
std::vector<TShape> *out_shape, | |
std::vector<TShape> *aux_shape) const override { | |
using namespace mshadow; | |
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, grid]"; | |
const TShape &dshape = (*in_shape)[bs::kData]; | |
const TShape &lshape = (*in_shape)[bs::kGrid]; | |
if (dshape.ndim() == 0) return false; | |
CHECK_EQ(dshape.ndim(), 4U) \ | |
<< "input data should be 4D in batch-num_filter-y-x"; | |
if (lshape.ndim() == 0) return false; | |
CHECK_EQ(lshape.ndim(), 4U) \ | |
<< "Sampler grid should be 4D in batch-2-y-x"; | |
CHECK_EQ(dshape[0], lshape[0]); | |
CHECK_EQ(lshape[1], 2U) << "incorrect grid shape[1], should be 2"; | |
// target height | |
CHECK_GT(lshape[2], 0U) \ | |
<< "incorrect grid_shape: " << lshape[2]; | |
// target width | |
CHECK_GT(lshape[3], 0U) \ | |
<< "incorrect grid_shape: " << lshape[3]; | |
out_shape->clear(); | |
// output_shape : (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]) | |
out_shape->push_back(dshape); | |
(*out_shape)[bs::kOut][2] = lshape[2]; | |
(*out_shape)[bs::kOut][3] = lshape[3]; | |
out_shape->push_back(Shape4(lshape[0], lshape[2], lshape[3], 2)); | |
return true; | |
} | |
bool InferType(std::vector<int> *in_type, | |
std::vector<int> *out_type, | |
std::vector<int> *aux_type) const override { | |
int dtype = -1; | |
for (size_t i = 0; i < in_type->size(); ++i) { | |
if (dtype == -1) { | |
dtype = in_type->at(i); | |
} else { | |
CHECK(in_type->at(i) == dtype || | |
in_type->at(i) == -1) << | |
"Non-uniform data type in BilinearSampler"; | |
} | |
} | |
if (dtype == -1) { | |
LOG(FATAL) << "Not enough information to infer type in BilinearSampler."; | |
return false; | |
} | |
size_t nin = this->ListArguments().size(); | |
in_type->clear(); | |
for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype); | |
size_t naux = this->ListAuxiliaryStates().size(); | |
aux_type->clear(); | |
for (size_t i = 0; i < naux; ++i) aux_type->push_back(dtype); | |
size_t nout = this->ListOutputs().size(); | |
out_type->clear(); | |
for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype); | |
return true; | |
} | |
OperatorProperty* Copy() const override { | |
auto ptr = new BilinearSamplerProp(); | |
ptr->param_ = param_; | |
return ptr; | |
} | |
std::string TypeString() const override { | |
return "BilinearSampler"; | |
} | |
std::vector<int> DeclareBackwardDependency( | |
const std::vector<int> &out_grad, | |
const std::vector<int> &in_data, | |
const std::vector<int> &out_data) const override { | |
return {out_grad[bs::kOut], | |
in_data[bs::kData], | |
out_data[bs::kTmp], | |
in_data[bs::kGrid]}; | |
} | |
Operator* CreateOperator(Context ctx) const override { | |
LOG(FATAL) << "Not Implemented."; | |
return NULL; | |
} | |
Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, | |
std::vector<int> *in_type) const override; | |
private: | |
BilinearSamplerParam param_; | |
}; // class BilinearSamplerProp | |
#endif // DMLC_USE_CXX11 | |
} // namespace op | |
} // namespace mxnet | |
#endif // MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ |