blob: 5c0eb0bedcc88a6d3b6211a6041517c25341dd77 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file multisample_op.cc
* \brief CPU-implementation of multi-sampling operators
*/
#include "./multisample_op.h"
namespace mxnet {
namespace op {
/*!
 * \brief Sampler for uniform distributions between two parameters.
 *
 * Integral targets are drawn with std::uniform_int_distribution<DType>. Every other
 * target goes through std::uniform_real_distribution: in DType precision when DType is a
 * standard floating-point type, and in double precision otherwise (this is how
 * non-fundamental types such as half_t are handled). The draw is converted to DType on
 * return.
 *
 * NOTE(review): std::uniform_int_distribution covers the closed range [a, b], whereas the
 * sample_uniform description registered in this file says [low,high). This only matters
 * if integral output types can reach this sampler — confirm against MultiSampleOpType.
 */
struct UniformSampler {
  template<typename DType>
  struct Sampler {
    std::mt19937 rnd_;
    // Real type used for non-integral targets; falls back to double for half_t & co.
    using FType = typename std::conditional<std::is_floating_point<DType>::value,
                                            DType, double>::type;
    // Integer distribution for integral targets, real distribution for everything else.
    using GType = typename std::conditional<std::is_integral<DType>::value,
                                            std::uniform_int_distribution<DType>,
                                            std::uniform_real_distribution<FType>>::type;
    GType gen_;
    template<typename ParamT>
    Sampler(ParamT low, ParamT high, int seed) : rnd_(seed), gen_(low, high) {}
    MSHADOW_XINLINE DType operator()() {
      return gen_(rnd_);
    }
  };
};
/*!
 * \brief Sampler for normal distributions with mean mu and standard deviation sigma.
 *
 * The parameters are forwarded to std::normal_distribution, which requires a
 * floating-point result type. For any other DType the distribution is instantiated over
 * double so the template still compiles, and the constructor rejects such targets at
 * run time.
 *
 * NOTE(review): std::is_floating_point is false for non-fundamental types such as half_t,
 * so a half_t target fails the check below even though GType would handle it via double.
 * Confirm whether float16 outputs can reach this sampler.
 */
struct NormalSampler {
  template<typename DType>
  struct Sampler {
    std::mt19937 rnd_;
    using GType = typename std::conditional<std::is_floating_point<DType>::value,
                                            DType, double>::type;
    std::normal_distribution<GType> gen_;
    template<typename ParamT>
    Sampler(ParamT mu, ParamT sigma, int seed) : rnd_(seed), gen_(mu, sigma) {
      // Checked at run time rather than with static_assert so that integral
      // instantiations of this template still compile.
      CHECK_EQ(std::is_floating_point<DType>::value, true)
        << "Normal distribution must have floating point target type";
    }
    MSHADOW_XINLINE DType operator()() {
      return gen_(rnd_);
    }
  };
};
/*!
 * \brief Sampler for gamma distributions.
 *
 * alpha and beta are forwarded unchanged to std::gamma_distribution, so they act as its
 * shape and scale parameters respectively.
 *
 * NOTE(review): as with NormalSampler, a non-fundamental DType such as half_t fails the
 * floating-point check below — confirm whether float16 outputs can reach this sampler.
 */
struct GammaSampler {
  template<typename DType>
  struct Sampler {
    std::mt19937 rnd_;
    // std::gamma_distribution only accepts floating-point types; substituting double for
    // any other DType keeps the template compilable for integral instantiations.
    using GType = typename std::conditional<std::is_floating_point<DType>::value,
                                            DType, double>::type;
    std::gamma_distribution<GType> gen_;
    template<typename ParamT>
    Sampler(ParamT alpha, ParamT beta, int seed) : rnd_(seed), gen_(alpha, beta) {
      // Reject non-floating-point targets at run time instead of compile time.
      CHECK_EQ(std::is_floating_point<DType>::value, true)
        << "Gamma distribution must have floating point target type";
    }
    MSHADOW_XINLINE DType operator()() {
      return gen_(rnd_);
    }
  };
};
/*!
 * \brief Sampler for exponential distributions with rate lambda.
 *
 * Only the first constructor parameter is used; it is forwarded to
 * std::exponential_distribution as its rate. Non-floating-point DTypes compile against
 * a double distribution and are rejected at run time.
 *
 * NOTE(review): a non-fundamental DType such as half_t also fails the floating-point
 * check below — confirm whether float16 outputs can reach this sampler.
 */
struct ExponentialSampler {
  template<typename DType>
  struct Sampler {
    std::mt19937 rnd_;
    // Keeps the template compilable when instantiated for integral DTypes.
    using GType = typename std::conditional<std::is_floating_point<DType>::value,
                                            DType, double>::type;
    std::exponential_distribution<GType> gen_;
    template<typename ParamT>
    Sampler(ParamT lambda, ParamT /* unused */, int seed) : rnd_(seed), gen_(lambda) {
      CHECK_EQ(std::is_floating_point<DType>::value, true)
        << "Exponential distribution must have floating point target type";
    }
    MSHADOW_XINLINE DType operator()() {
      return gen_(rnd_);
    }
  };
};
/*!
 * \brief Sampler for Poisson distributions with mean lambda.
 *
 * Only the first constructor parameter is used. std::poisson_distribution produces
 * integers: integral targets are drawn directly in DType, while any other target draws
 * an int and converts it on return, so floating-point outputs are supported as well.
 */
struct PoissonSampler {
  template<typename DType>
  struct Sampler {
    std::mt19937 rnd_;
    using GType = typename std::conditional<std::is_integral<DType>::value,
                                            DType, int>::type;
    std::poisson_distribution<GType> gen_;
    template<typename ParamT>
    Sampler(ParamT lambda, ParamT /* unused */, int seed) : rnd_(seed), gen_(lambda) {}
    MSHADOW_XINLINE DType operator()() {
      const GType count = gen_(rnd_);
      return static_cast<DType>(count);
    }
  };
};
/*!
 * \brief Sampler for negative binomial distributions, parameterized exactly as
 *        std::negative_binomial_distribution(k, p).
 *
 * k is converted to the integral GType, so a fractional k is truncated. The registered
 * operator description words the parameters with the roles of "success" and "failure"
 * swapped relative to the standard library's wording. Integral targets are drawn
 * directly; any other target draws an int and converts it on return.
 */
struct NegativeBinomialSampler {
  template<typename DType>
  struct Sampler {
    std::mt19937 rnd_;
    using GType = typename std::conditional<std::is_integral<DType>::value,
                                            DType, int>::type;
    std::negative_binomial_distribution<GType> gen_;
    template<typename ParamT>
    Sampler(ParamT k, ParamT p, int seed) : rnd_(seed), gen_(k, p) {}
    MSHADOW_XINLINE DType operator()() {
      const GType count = gen_(rnd_);
      return static_cast<DType>(count);
    }
  };
};
/*!
 * \brief Sampler for the generalized negative binomial distribution NegBin(mu, alpha).
 *
 * Realized as a Poisson-gamma mixture: X ~ NegBin(mu, alpha) corresponds to
 * X ~ Poisson(Gamma(1/alpha, mu*alpha)), where Gamma takes (shape, scale) exactly like
 * std::gamma_distribution. The mixing gamma therefore has mean mu and variance
 * alpha*mu^2, so alpha controls the over-dispersion relative to Poisson(mu).
 * The boundary case alpha == 0 is sampled as a plain Poisson distribution with mean mu.
 */
struct GeneralizedNegativeBinomialSampler {
  template<typename DType>
  struct Sampler {
    // Allow sampling of a Poisson distribution also to output floating point types.
    typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
    // True in the boundary case alpha == 0, where no gamma mixing is performed.
    bool poisson_;
    // Mean of the distribution; used directly as the Poisson mean when poisson_ is set.
    double mu_;
    std::mt19937 rnd_;
    // Distribution of the Poisson mean; only sampled when poisson_ is false.
    std::gamma_distribution<> gen_;
    template<typename PType>
    Sampler(PType mu, PType alpha, int seed)
      : poisson_(alpha == 0.0), mu_(mu), rnd_(seed),
        gen_(GammaShape(alpha), GammaScale(mu, alpha)) {}
    MSHADOW_XINLINE DType operator()() {
      // Draw the Poisson mean first (unless in the Poisson limit), then the count.
      const double mean = poisson_ ? mu_ : gen_(rnd_);
      return static_cast<DType>(std::poisson_distribution<GType>(mean)(rnd_));
    }

   private:
    // Shape 1/alpha of the mixing gamma. Computed in double: with integral parameter
    // types, PType(1)/alpha truncates to 0 for any alpha >= 2, which violates
    // std::gamma_distribution's shape > 0 requirement. Returns 1 when alpha == 0
    // (Poisson limit, gen_ is never sampled).
    static double GammaShape(double alpha) {
      return alpha == 0.0 ? 1.0 : 1.0 / alpha;
    }
    // Scale mu*alpha of the mixing gamma, computed in double. Returns 1 when alpha == 0
    // so gen_ is never constructed with a zero scale; std::gamma_distribution requires
    // scale > 0 even if the distribution is never sampled.
    static double GammaScale(double mu, double alpha) {
      return alpha == 0.0 ? 1.0 : mu * alpha;
    }
  };
};
// Register the MultiSampleParam parameter struct used by all sampling operators below
// (presumably declared in multisample_op.h — not visible here).
DMLC_REGISTER_PARAMETER(MultiSampleParam);
// Registers an NNVM operator named sample_<distr> with `num_inputs` parameter inputs and a
// single output. Its attributes are: parameter parsing via MultiSampleParam, shape/type
// inference via MultiSampleOpShape/MultiSampleOpType, one kRandom resource, the CPU
// forward MultiSampleOpForward<cpu, sampler>, and zero gradients (MakeZeroGradNodes).
// FListInputNames returns the first `num_inputs` entries of {input_name_1, input_name_2}.
// Only the first input's argument doc is added here; the SAMPLING2 wrapper adds the second.
// NOTE: comments must stay outside the macro body, since a `//` comment on a line ending
// in a backslash would swallow the continued line.
#define MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, num_inputs, \
input_name_1, input_name_2, description) \
NNVM_REGISTER_OP(sample_##distr) \
.MXNET_DESCRIBE("Multi-sampling from " description "." \
" The parameters of the distributions are provided as input tensor(s)." \
" Let \"[s]\" be the shape of the input tensor(s), \"n\" be the dimension of [s], \"[t]\"" \
" be the shape specified as the parameter of the operator, and \"m\" be the dimension" \
" of [t]. Then the output will be a (n+m)-dimensional tensor with shape [s]x[t]. For any" \
" valid n-dimensional index \"i\" with respect to the input tensor(s), output[i] will be" \
" an m-dimensional tensor that holds randomly drawn samples from the distribution which" \
" is parameterized by the input values at index i. If the shape parameter of the operator" \
" is not set, then one sample will be drawn per distribution and the output tensor has" \
" the same dimensions as the input tensor(s).") \
.set_num_inputs(num_inputs) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<MultiSampleParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
std::vector<std::string> v = {input_name_1, input_name_2}; v.resize(num_inputs); return v; \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", MultiSampleOpShape) \
.set_attr<nnvm::FInferType>("FInferType", MultiSampleOpType) \
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>(1, ResourceRequest::kRandom); \
}) \
.set_attr<FCompute>("FCompute<cpu>", MultiSampleOpForward<cpu, sampler>) \
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
.add_arguments(MultiSampleParam::__FIELDS__()) \
.add_argument(input_name_1, "NDArray", "")
// Single-input registration. input_name is passed twice only to fill both name slots;
// the second one is dropped by v.resize(1). The trailing ';' is part of the macro, so
// invocations are written without one.
#define MXNET_OPERATOR_REGISTER_SAMPLING1(distr, sampler, input_name, description) \
MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, 1, input_name, input_name, description);
// Two-input registration; also documents the second input. The trailing ';' is part of
// the macro, so invocations are written without one.
#define MXNET_OPERATOR_REGISTER_SAMPLING2(distr, sampler, input_name_1, input_name_2, description) \
MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, 2, input_name_1, input_name_2, description) \
.add_argument(input_name_2, "NDArray", "");
// Operator registrations. Each invocation defines sample_<distr> (e.g. sample_uniform)
// whose inputs carry the quoted names: SAMPLING2 declares two inputs, SAMPLING1 one.
MXNET_OPERATOR_REGISTER_SAMPLING2(uniform, UniformSampler, "low", "high",
"uniform distributions on the interval [low,high)")
MXNET_OPERATOR_REGISTER_SAMPLING2(normal, NormalSampler, "mu", "sigma",
"normal distributions with parameters mu and sigma")
// GammaSampler hands (alpha, beta) to std::gamma_distribution as (shape, scale) —
// assuming MultiSampleOpForward passes the inputs in declaration order (not visible here).
MXNET_OPERATOR_REGISTER_SAMPLING2(gamma, GammaSampler, "alpha", "beta",
"gamma distributions with parameters alpha and beta")
// Single-input distributions; their samplers ignore the second constructor argument.
MXNET_OPERATOR_REGISTER_SAMPLING1(exponential, ExponentialSampler, "lam",
"exponential distributions with parameters lambda")
MXNET_OPERATOR_REGISTER_SAMPLING1(poisson, PoissonSampler, "lam",
"Poisson distributions with parameters lambda")
// NegativeBinomialSampler uses std::negative_binomial_distribution(k, p) directly; the
// description below names the same parameters with success/failure roles swapped.
MXNET_OPERATOR_REGISTER_SAMPLING2(negative_binomial, NegativeBinomialSampler, "k", "p",
"negative binomial distributions with parameters k (failure limit) and p (failure probability)")
// Poisson-gamma mixture with mean mu; alpha == 0 degenerates to Poisson(mu).
MXNET_OPERATOR_REGISTER_SAMPLING2(generalized_negative_binomial,
GeneralizedNegativeBinomialSampler, "mu", "alpha",
"generalized negative binomial distributions with parameters mu (mean)"
" and alpha (over dispersion)")
} // namespace op
} // namespace mxnet