blob: 1547fdee6e2e2d54f372081b9d01154a9cdd7b46 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file instance_norm.cc
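* \brief Registers the InstanceNorm operator and its backward operator (_backward_instance_norm) with CPU FCompute kernels.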
* \author Sebastian Bodenstein
*/
#include "./instance_norm-inl.h"
namespace mxnet {
namespace op {
// Define the dmlc parameter manager for InstanceNormParam. The
// ParamParser<InstanceNormParam> attribute parsers set on both operators
// registered below depend on this registration.
DMLC_REGISTER_PARAMETER(InstanceNormParam);
struct InstanceNormGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
std::vector<nnvm::NodeEntry> out_data;
out_data.reserve(n->num_outputs());
for (size_t i = 0; i < n->num_outputs(); ++i)
out_data.emplace_back(n, i, 0);
std::vector<nnvm::NodeEntry> heads;
heads.reserve(5);
heads.emplace_back(ograds.at(instance_norm::kOut));
heads.emplace_back(out_data.at(instance_norm::kMean));
heads.emplace_back(out_data.at(instance_norm::kVar));
heads.emplace_back(n->inputs.at(instance_norm::kData));
heads.emplace_back(n->inputs.at(instance_norm::kGamma));
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};
// Forward InstanceNorm operator.
//
// It has three outputs, indexed by instance_norm::kOut, kMean and kVar. Only
// the first is visible to users (FNumVisibleOutputs returns 1). The mean and
// var outputs are hidden and are consumed by InstanceNormGrad when it builds
// the _backward_instance_norm node.
NNVM_REGISTER_OP(InstanceNorm)
.add_alias("_npx_instance_norm")
.describe(R"code(Applies instance normalization to the n-dimensional input array.
This operator takes an n-dimensional input array where (n>2) and normalizes
the input using the following formula:
.. math::
out = \frac{x - mean[data]}{ \sqrt{Var[data] + \epsilon}} * gamma + beta
This layer is similar to batch normalization layer (`BatchNorm`)
with two differences: first, the normalization is
carried out per example (instance), not over a batch. Second, the
same normalization is applied both at test and train time. This
operation is also known as `contrast normalization`.
If the input data is of shape [batch, channel, spatial_dim1, spatial_dim2, ...],
`gamma` and `beta` parameters must be vectors of shape [channel].
This implementation is based on this paper [1]_
.. [1] Instance Normalization: The Missing Ingredient for Fast Stylization,
D. Ulyanov, A. Vedaldi, V. Lempitsky, 2016 (arXiv:1607.08022v2).
Examples::
// Input of shape (2,1,2)
x = [[[ 1.1, 2.2]],
[[ 3.3, 4.4]]]
// gamma parameter of length 1
gamma = [1.5]
// beta parameter of length 1
beta = [0.5]
// Instance normalization is calculated with the above formula
InstanceNorm(x,gamma,beta) = [[[-0.997527 , 1.99752665]],
[[-0.99752653, 1.99752724]]]
)code" ADD_FILELINE)
.add_argument("data", "NDArray-or-Symbol",
              "An n-dimensional input array (n > 2) of the form [batch, "
              "channel, spatial_dim1, spatial_dim2, ...].")
.add_argument("gamma", "NDArray-or-Symbol",
              "A vector of length \'channel\', which multiplies the "
              "normalized input.")
.add_argument("beta", "NDArray-or-Symbol",
              "A vector of length \'channel\', which is added to the "
              "product of the normalized input and the weight.")
.add_arguments(InstanceNormParam::__FIELDS__())
.set_num_inputs(3)
.set_num_outputs(3)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"data", "gamma", "beta"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
    [](const NodeAttrs& attrs) {
  // Provide a name for every output declared by set_num_outputs(3), not just
  // the visible one. Output names are looked up by output index, so
  // resolving the name of a hidden mean/var entry (for example, when one
  // becomes a graph or subgraph output) would otherwise read past the end of
  // a one-element list. Filling by enum index keeps each name aligned with
  // the slot that InstanceNormGrad reads.
  std::vector<std::string> names(3);
  names[instance_norm::kOut] = "output";
  names[instance_norm::kMean] = "mean";
  names[instance_norm::kVar] = "var";
  return names;
})
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
    [](const NodeAttrs& attrs) { return 1; })
.set_attr_parser(ParamParser<InstanceNormParam>)
.set_attr<mxnet::FInferShape>("FInferShape", InstanceNormShape)
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<nnvm::FGradient>("FGradient", InstanceNormGrad{"_backward_instance_norm"})
.set_attr<FCompute>("FCompute<cpu>", InstanceNormForward<cpu>);
// Backward operator for InstanceNorm.
//
// Its five inputs are the entries that InstanceNormGrad assembles: the output
// gradient, mean, var, data and gamma. It declares three outputs, presumably
// the gradients for data/gamma/beta (TODO confirm against
// InstanceNormBackward). It requests one temporary-workspace resource.
NNVM_REGISTER_OP(_backward_instance_norm)
.set_num_inputs(5)
.set_num_outputs(3)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<InstanceNormParam>)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& /*attrs*/) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", InstanceNormBackward<cpu>);
} // namespace op
} // namespace mxnet