blob: fe91a143c23e942722bf0bc39bc64296a11b4c6b [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file spatial_transformer.cc
* \brief CPU implementation of the spatial transformer operator
* \author Wei Wu
*/
#include "./spatial_transformer-inl.h"
namespace mshadow {
template<typename DType>
inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
const Tensor<cpu, 4, DType> &input,
const Tensor<cpu, 3, DType> grid_src) {
DType *out = output.dptr_;
const DType *data = input.dptr_;
const DType *grid = grid_src.dptr_;
int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
for (index_t w = 0; w < o_w; ++w) {
index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
index_t top_left_y = std::min(i_h, std::max(0, static_cast<int>(floor(y_real))));
index_t top_left_x = std::min(i_w, std::max(0, static_cast<int>(floor(x_real))));
DType top_left_y_w = 1.0 - (y_real - top_left_y);
DType top_left_x_w = 1.0 - (x_real - top_left_x);
index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
DType top_left_v = *(data + data_index);
DType top_right_v = *(data + data_index + 1);
DType bottom_left_v = *(data + data_index + i_w);
DType bottom_right_v = *(data + data_index + i_w + 1);
*(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w +
bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
}
}
}
}
}
template<typename DType>
inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
const Tensor<cpu, 3, DType> &grid_src_data,
const Tensor<cpu, 4, DType> &output_grad,
const Tensor<cpu, 4, DType> &input_data) {
DType *g_input = input_grad.dptr_;
DType *grid_src = grid_src_data.dptr_;
const DType *grad = output_grad.dptr_;
const DType *data = input_data.dptr_;
int o_n = output_grad.size(0), o_c = output_grad.size(1),
o_h = output_grad.size(2), o_w = output_grad.size(3);
int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
DType top_left_y_gw = 0.0;
DType top_left_x_gw = 0.0;
index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
index_t top_left_y = std::min(i_h, std::max(0, static_cast<int>(floor(y_real))));
index_t top_left_x = std::min(i_w, std::max(0, static_cast<int>(floor(x_real))));
DType top_left_y_w = 1.0 - (y_real - top_left_y);
DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < o_c; ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
+ top_left_x;
// calc 4 vertex value in input data
DType top_left_v = *(data + data_index);
DType top_right_v = *(data + data_index + 1);
DType bottom_left_v = *(data + data_index + i_w);
DType bottom_right_v = *(data + data_index + i_w + 1);
// calc input grad
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
*(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w);
*(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w;
*(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w);
// calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
* top_left_x_w);
top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
* top_left_y_w);
}
// calc grid_src grad
*(grid_src + grid_src_index + o_h * o_w) = top_left_y_gw * (i_h - 1) / 2;
*(grid_src + grid_src_index) = top_left_x_gw * (i_w - 1) / 2;
}
}
}
}
} // namespace mshadow
namespace mxnet {
namespace op {
/*!
 * \brief CPU specialization of the operator factory.
 *
 * Allocates a SpatialTransformerOp<cpu, DType> with new, where DType is selected
 * by MSHADOW_REAL_TYPE_SWITCH from dtype. The raw pointer is handed to the caller.
 *
 * \param param operator parameters forwarded to the SpatialTransformerOp constructor
 * \param dtype type flag passed to MSHADOW_REAL_TYPE_SWITCH
 * \return the newly allocated operator; stays NULL if the switch body never runs
 */
template<>
Operator* CreateOp<cpu>(SpatialTransformerParam param, int dtype) {
  Operator *created = NULL;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    created = new SpatialTransformerOp<cpu, DType>(param);
  })
  return created;
}
/*!
 * \brief Validate the input types and shapes, then create the operator.
 *
 * InferType and InferShape are run first (in that order) and must both succeed;
 * their output / auxiliary results are only needed for the check and are discarded.
 * The operator itself is produced by DO_BIND_DISPATCH with CreateOp, param_ and the
 * type of the first input (the macro is defined elsewhere and supplies the return).
 *
 * \param ctx context the operator is created for
 * \param in_shape shapes of the inputs, passed to InferShape
 * \param in_type types of the inputs, passed to InferType; element 0 selects the dtype
 */
Operator *SpatialTransformerProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                                   std::vector<int> *in_type) const {
  std::vector<int> inferred_out_type;
  std::vector<int> inferred_aux_type;
  std::vector<TShape> inferred_out_shape;
  std::vector<TShape> inferred_aux_shape;
  CHECK(InferType(in_type, &inferred_out_type, &inferred_aux_type));
  CHECK(InferShape(in_shape, &inferred_out_shape, &inferred_aux_shape));
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}
// Register the parameter struct and the SpatialTransformer operator property,
// together with the argument descriptions attached to the operator.
DMLC_REGISTER_PARAMETER(SpatialTransformerParam);
MXNET_REGISTER_OP_PROPERTY(SpatialTransformer, SpatialTransformerProp)
.add_argument("data", "NDArray-or-Symbol",
              "Input data to the SpatialTransformerOp.")
.add_argument("loc", "NDArray-or-Symbol",
              "localisation net, the output dim should be 6 when transform_type "
              "is affine. You should initialize the weight and bias with identity transform.")
.add_arguments(SpatialTransformerParam::__FIELDS__())
.describe("Applies a spatial transformer to input feature map.");
} // namespace op
} // namespace mxnet