blob: f0aa246ecae61ecb935f1ca6017919a01f80e751 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file sample_multinomial_op.cc
* \brief Operator for sampling from multinomial distributions
*/
#include "./sample_multinomial_op.h"
namespace mxnet {
namespace op {
// Register the SampleMultinomialParam parameter struct with dmlc so that the
// operator's attributes can be parsed into it (the _sample_multinomial op uses
// ParamParser<SampleMultinomialParam> as its attribute parser).
DMLC_REGISTER_PARAMETER(SampleMultinomialParam);
// Forward sampling operator. With get_prob=false it has a single output (the
// sampled outcome indices); with get_prob=true it has a second output holding
// the log likelihood of each drawn sample, which is the only output that
// gradients are propagated through (see FGradient below).
NNVM_REGISTER_OP(_sample_multinomial)
.add_alias("sample_multinomial")
.add_alias("_npx__random_categorical")
.describe(R"code(Concurrent sampling from multiple multinomial distributions.

*data* is an *n* dimensional array whose last dimension has length *k*, where
*k* is the number of possible outcomes of each multinomial distribution. This
operator will draw *shape* samples from each distribution. If shape is empty
one sample will be drawn from each distribution.

If *get_prob* is true, a second array containing log likelihood of the drawn
samples will also be returned. This is usually used for reinforcement learning
where you can provide reward as head gradient for this array to estimate
gradient.

Note that the input distribution must be normalized, i.e. *data* must sum to
1 along its last axis.

Examples::

   probs = [[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]]

   // Draw a single sample for each distribution
   sample_multinomial(probs) = [3, 0]

   // Draw a vector containing two samples for each distribution
   sample_multinomial(probs, shape=(2)) = [[4, 2],
                                           [0, 0]]

   // requests log likelihood: the second array is [log(0.2), log(0.3)]
   sample_multinomial(probs, get_prob=True) = [2, 1], [-1.6094, -1.2040]
)code")
.set_num_inputs(1)
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
    const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
    // The log-likelihood output only exists when it was requested.
    return param.get_prob ? 2U : 1U;
  })
.set_attr_parser(ParamParser<SampleMultinomialParam>)
.set_attr<mxnet::FInferShape>("FInferShape", SampleMultinomialOpShape)
.set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const nnvm::NodeAttrs& attrs) {
    // The forward pass is given a random generator and a temporary workspace.
    return std::vector<ResourceRequest>{
      ResourceRequest::kRandom, ResourceRequest::kTempSpace};
  })
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(n->attrs.parsed);
    if (param.get_prob) {
      // Only the log-likelihood output (index 1) carries a gradient. The
      // backward op additionally needs the input probabilities and the sampled
      // indices (output 0) to know which entries of the input to update.
      return MakeGradNode("_backward_sample_multinomial", n,
                          {ograds[1], n->inputs[0], nnvm::NodeEntry{n, 0, 0}},
                          std::unordered_map<std::string, std::string>());
    } else {
      // Without get_prob the only output is the discrete sample, so the input
      // gets zero gradients.
      return MakeZeroGradNodes(n, ograds);
    }
  })
.set_attr<FCompute>("FCompute<cpu>", SampleMultinomialForward<cpu>)
.add_argument("data", "NDArray-or-Symbol",
              "Distribution probabilities. Must sum to one on the last axis.")
.add_arguments(SampleMultinomialParam::__FIELDS__());
/*!
 * \brief CPU kernel for the backward pass of _sample_multinomial.
 *
 * One invocation processes distribution (row) i. For each of the M samples
 * drawn from that row, with k = out[i*M + j] being the sampled outcome, it
 * accumulates
 *   igrad[i*K + k] += ograd[i*M + j] / dist[i*K + k],
 * i.e. the incoming gradient times d(log p)/dp evaluated at the sampled
 * outcome. Each invocation only writes to its own row of igrad.
 *
 * \param i     index of the distribution (row) being processed.
 * \param K     number of outcomes per distribution (row length of dist/igrad).
 * \param M     number of samples per distribution (row length of ograd/out).
 * \param ograd head gradient of the log-likelihood output.
 * \param dist  input probabilities.
 * \param out   sampled outcome indices.
 * \param igrad gradient w.r.t. dist; accumulated into with +=, so it is
 *              presumably initialized by the caller (TODO confirm in header).
 */
struct SampleMultinomialBackwardCPUKernel {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i, index_t K, index_t M,
                                  DType* ograd, DType* dist, IType* out,
                                  DType* igrad) {
    // index_t rather than int for i: avoids truncating the row index when the
    // number of distributions exceeds INT_MAX (large-tensor builds).
    const index_t dist_row = i * K;
    const index_t sample_row = i * M;
    for (index_t j = 0; j < M; ++j) {
      const index_t k = static_cast<index_t>(out[sample_row + j]);
      igrad[dist_row + k] += ograd[sample_row + j] / dist[dist_row + k];
    }
  }
};
// Backward of _sample_multinomial; its FGradient only creates this node when
// get_prob is true. Inputs arrive in the order that FGradient passes them:
//   0: head gradient of the log-likelihood output,
//   1: the input probabilities (the forward op's data),
//   2: the sampled outcome indices (the forward op's output 0).
// The single output is the gradient w.r.t. the input probabilities.
NNVM_REGISTER_OP(_backward_sample_multinomial)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<FCompute>(
    "FCompute<cpu>",
    SampleMultinomialBackward<SampleMultinomialBackwardCPUKernel, cpu>);
} // namespace op
} // namespace mxnet