| /*! |
| * Copyright (c) 2017 by Contributors |
| * \file multisample_op.cc |
| * \brief CPU-implementation of multi-sampling operators |
| */ |
| #include "./multisample_op.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| struct UniformSampler { |
| template<typename DType> |
| struct Sampler{ |
| std::mt19937 rnd_; |
| // Ensure that half_t is handled correctly. |
| typedef typename std::conditional<std::is_floating_point<DType>::value, |
| DType, double>::type FType; |
| typedef typename std::conditional<std::is_integral<DType>::value, |
| std::uniform_int_distribution<DType>, |
| std::uniform_real_distribution<FType>>::type GType; |
| GType gen_; |
| template<typename PType> |
| Sampler(PType a, PType b, int seed): rnd_(seed), gen_(a, b) { } |
| MSHADOW_XINLINE DType operator()() { return gen_(rnd_); } |
| }; |
| }; |
| |
| struct NormalSampler { |
| template<typename DType> |
| struct Sampler{ |
| std::mt19937 rnd_; |
| typedef typename std::conditional<std::is_floating_point<DType>::value, |
| DType, double>::type GType; |
| std::normal_distribution<GType> gen_; |
| template<typename PType> |
| Sampler(PType mu, PType sigma, int seed): rnd_(seed), gen_(mu, sigma) |
| { CHECK_EQ(std::is_floating_point<DType>::value, true) |
| << "Normal distribution must have floating point target type"; }; |
| MSHADOW_XINLINE DType operator()() { return gen_(rnd_); } |
| }; |
| }; |
| |
| struct GammaSampler { |
| template<typename DType> |
| struct Sampler{ |
| std::mt19937 rnd_; |
| // Avoid problems with static check during compilation for integral types. |
| typedef typename std::conditional<std::is_floating_point<DType>::value, |
| DType, double>::type GType; |
| std::gamma_distribution<GType> gen_; |
| template<typename PType> |
| Sampler(PType alpha, PType beta, int seed): rnd_(seed), gen_(alpha, beta) |
| { CHECK_EQ(std::is_floating_point<DType>::value, true) |
| << "Gamma distribution must have floating point target type"; }; |
| MSHADOW_XINLINE DType operator()() { return gen_(rnd_); } |
| }; |
| }; |
| |
| struct ExponentialSampler { |
| template<typename DType> |
| struct Sampler{ |
| std::mt19937 rnd_; |
| // Avoid problems with static check during compilation for integral types. |
| typedef typename std::conditional<std::is_floating_point<DType>::value, |
| DType, double>::type GType; |
| std::exponential_distribution<GType> gen_; |
| template<typename PType> |
| Sampler(PType lambda, PType , int seed): rnd_(seed), gen_(lambda) |
| { CHECK_EQ(std::is_floating_point<DType>::value, true) |
| << "Exponential distribution must have floating point target type"; }; |
| MSHADOW_XINLINE DType operator()() { return gen_(rnd_); } |
| }; |
| }; |
| |
| struct PoissonSampler { |
| template<typename DType> |
| struct Sampler{ |
| std::mt19937 rnd_; |
| // Allow sampling of a Poisson distribution also to output floating point types. |
| typedef typename std::conditional<std::is_integral<DType>::value, |
| DType, int>::type GType; |
| std::poisson_distribution<GType> gen_; |
| template<typename PType> |
| Sampler(PType lambda, PType , int seed): rnd_(seed), gen_(lambda) { } |
| MSHADOW_XINLINE DType operator()() { return static_cast<DType>(gen_(rnd_)); } |
| }; |
| }; |
| |
| // Negative binomial distribution as defined in C++ standard library |
| struct NegativeBinomialSampler { |
| template<typename DType> |
| struct Sampler{ |
| std::mt19937 rnd_; |
| // Allow sampling of a negative binomial distribution also to output floating point types. |
| typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType; |
| std::negative_binomial_distribution<GType> gen_; |
| template<typename PType> |
| Sampler(PType k, PType p, int seed): rnd_(seed), gen_(k, p) {} |
| MSHADOW_XINLINE DType operator()() { return static_cast<DType>(gen_(rnd_)); } |
| }; |
| }; |
| |
| // Generalized form of the negative binomial distribution which is generated by |
| // a poisson-gamma mixture: X ~ NegBin(mu, alpha) corresponds to |
| // X ~ Poisson(Gamma(1/alpha,mu*alpha)) |
| struct GeneralizedNegativeBinomialSampler { |
| template<typename DType> |
| struct Sampler { |
| // We allow the boundary case where the negative binomial equals the Poisson distribution |
| bool poisson_; |
| double mu_; |
| std::mt19937 rnd_; |
| // Realize the negative binomial by a Poisson distribution over a gamma distributed mean. |
| std::gamma_distribution<> gen_; |
| template<typename PType> |
| Sampler(PType mu, PType alpha, int seed): poisson_(alpha == 0.0), mu_(mu), rnd_(seed), |
| gen_((alpha == PType(0) ? PType(1) : PType(1)/alpha), mu*alpha) {} |
| // Allow sampling of a Poisson distribution also to output floating point types. |
| typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType; |
| MSHADOW_XINLINE DType operator()() { return static_cast<DType>( |
| std::poisson_distribution<GType>(poisson_ ? mu_ : gen_(rnd_))(rnd_)); } |
| }; |
| }; |
| |
| |
// Registers MultiSampleParam (presumably declared in ./multisample_op.h) with the dmlc
// parameter registry. It is the attribute type parsed by ParamParser<MultiSampleParam>
// for every sampling operator registered in this file.
DMLC_REGISTER_PARAMETER(MultiSampleParam);
| |
// Registers the CPU multi-sampling operator `sample_<distr>`.
//   distr        - suffix of the operator name (token-pasted into sample_##distr).
//   sampler      - sampler struct (e.g. UniformSampler) passed to MultiSampleOpForward<cpu, sampler>.
//   num_inputs   - number of distribution-parameter input tensors.
//   input_name_1, input_name_2 - names of those inputs. FListInputNames keeps only
//                  the first num_inputs of them.
//   description  - string literal spliced into the generated operator documentation.
// Shape and type inference go through MultiSampleOpShape / MultiSampleOpType, and a
// ResourceRequest::kRandom resource is requested. The gradient is MakeZeroGradNodes
// (presumably zero gradients w.r.t. all inputs -- confirm).
// Only input_name_1 is documented via add_argument. The expansion has no trailing
// semicolon, so callers can chain further calls (see MXNET_OPERATOR_REGISTER_SAMPLING2).
#define MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, num_inputs, \
                                         input_name_1, input_name_2, description) \
  NNVM_REGISTER_OP(sample_##distr) \
  .MXNET_DESCRIBE("Multi-sampling from " description "." \
    " The parameters of the distributions are provided as input tensor(s)." \
    " Let \"[s]\" be the shape of the input tensor(s), \"n\" be the dimension of [s], \"[t]\"" \
    " be the shape specified as the parameter of the operator, and \"m\" be the dimension" \
    " of [t]. Then the output will be a (n+m)-dimensional tensor with shape [s]x[t]. For any" \
    " valid n-dimensional index \"i\" with respect to the input tensor(s), output[i] will be" \
    " an m-dimensional tensor that holds randomly drawn samples from the distribution which" \
    " is parameterized by the input values at index i. If the shape parameter of the operator" \
    " is not set, then one sample will be drawn per distribution and the output tensor has" \
    " the same dimensions as the input tensor(s).") \
  .set_num_inputs(num_inputs) \
  .set_num_outputs(1) \
  .set_attr_parser(ParamParser<MultiSampleParam>) \
  .set_attr<nnvm::FListInputNames>("FListInputNames", \
    [](const NodeAttrs& attrs) { \
      std::vector<std::string> v = {input_name_1, input_name_2}; v.resize(num_inputs); return v; \
    }) \
  .set_attr<nnvm::FInferShape>("FInferShape", MultiSampleOpShape) \
  .set_attr<nnvm::FInferType>("FInferType", MultiSampleOpType) \
  .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { \
      return std::vector<ResourceRequest>(1, ResourceRequest::kRandom); \
    }) \
  .set_attr<FCompute>("FCompute<cpu>", MultiSampleOpForward<cpu, sampler>) \
  .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
  .add_arguments(MultiSampleParam::__FIELDS__()) \
  .add_argument(input_name_1, "NDArray", "")

// Registers a one-parameter distribution. input_name is passed for both names, but
// only the first is used because num_inputs == 1. The macro supplies the terminating
// semicolon, so invocations do not end with one.
#define MXNET_OPERATOR_REGISTER_SAMPLING1(distr, sampler, input_name, description) \
  MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, 1, input_name, input_name, description);

// Registers a two-parameter distribution and also documents the second input. The
// macro supplies the terminating semicolon, so invocations do not end with one.
#define MXNET_OPERATOR_REGISTER_SAMPLING2(distr, sampler, input_name_1, input_name_2, description) \
  MXNET_OPERATOR_REGISTER_SAMPLING(distr, sampler, 2, input_name_1, input_name_2, description) \
  .add_argument(input_name_2, "NDArray", "");
| |
// Operator registrations. Each invocation defines `sample_<distr>` backed by the given
// sampler struct. The parameter input names listed here are the names exposed through
// FListInputNames.

// sample_uniform: inputs low, high.
MXNET_OPERATOR_REGISTER_SAMPLING2(uniform, UniformSampler, "low", "high",
  "uniform distributions on the interval [low,high)")
// sample_normal: inputs mu (mean), sigma (standard deviation).
MXNET_OPERATOR_REGISTER_SAMPLING2(normal, NormalSampler, "mu", "sigma",
  "normal distributions with parameters mu and sigma")
// sample_gamma: inputs alpha (shape), beta (scale), as in std::gamma_distribution.
MXNET_OPERATOR_REGISTER_SAMPLING2(gamma, GammaSampler, "alpha", "beta",
  "gamma distributions with parameters alpha and beta")
// sample_exponential: single input lam (rate).
MXNET_OPERATOR_REGISTER_SAMPLING1(exponential, ExponentialSampler, "lam",
  "exponential distributions with parameters lambda")
// sample_poisson: single input lam (mean).
MXNET_OPERATOR_REGISTER_SAMPLING1(poisson, PoissonSampler, "lam",
  "Poisson distributions with parameters lambda")
// sample_negative_binomial: inputs k, p, forwarded to std::negative_binomial_distribution(k, p).
// The standard library phrases this as failures before the k-th success with success
// probability p. The description below uses the same distribution with the roles of
// success and failure relabeled.
MXNET_OPERATOR_REGISTER_SAMPLING2(negative_binomial, NegativeBinomialSampler, "k", "p",
  "negative binomial distributions with parameters k (failure limit) and p (failure probability)")
// sample_generalized_negative_binomial: inputs mu (mean), alpha (over-dispersion),
// sampled as Poisson(Gamma(1/alpha, mu*alpha)). alpha == 0 reduces to Poisson(mu).
MXNET_OPERATOR_REGISTER_SAMPLING2(generalized_negative_binomial,
  GeneralizedNegativeBinomialSampler, "mu", "alpha",
  "generalized negative binomial distributions with parameters mu (mean)"
  " and alpha (over dispersion)")
| |
| } // namespace op |
| } // namespace mxnet |