/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file sample_multinomial_op.h
 * \brief Operators for sampling from categorical and multinomial distributions
 */
#ifndef MXNET_OPERATOR_RANDOM_SAMPLE_MULTINOMIAL_OP_H_
#define MXNET_OPERATOR_RANDOM_SAMPLE_MULTINOMIAL_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "./sampler.h"
namespace mxnet {
namespace op {
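/*! \brief Parameters for the categorical sampling operator. */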
struct SampleCategoricalParam : public dmlc::Parameter<SampleCategoricalParam> {
mxnet::TShape shape;
bool get_prob;
int dtype;
DMLC_DECLARE_PARAMETER(SampleCategoricalParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(mxnet::TShape(0, 1))
.describe("Shape to be sampled from each random distribution.");
DMLC_DECLARE_FIELD(get_prob).set_default(false).describe(
"Whether to also return the log probability of sampled "
"result. This is usually used for differentiating through "
"stochastic variables, e.g. in reinforcement learning.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(mshadow::kInt32)
.describe("DType of the output in case this can't be inferred.");
}
};
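/*! \brief Parameters for the multinomial sampling operator. */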
struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
mxnet::TShape shape;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(SampleMultinomialParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(mxnet::TShape(0, 1))
.describe("Shape to be sampled from each random distribution.");
DMLC_DECLARE_FIELD(ctx).set_default("").describe(
"Context of output, in format [cpu|gpu|cpu_pinned](n)."
" Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(mshadow::kInt32)
.describe("DType of the output in case this can't be inferred.");
}
};
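/*!
 * \brief Shape inference for categorical sampling. The last axis of the input
 * holds the event probabilities, so the output shape is the input shape with
 * that axis dropped and `shape` sample dimensions appended per distribution.
 */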
inline bool SampleCategoricalOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!ndim_is_known(ishape))
return false;
if (ishape.ndim() == 1) {
if (param.shape.ndim() > 0) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
if (param.get_prob)
SHAPE_ASSIGN_CHECK(*out_attrs, 1, param.shape);
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 1));
if (param.get_prob)
SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1));
}
return true;
}
mxnet::TShape oshape(ishape.ndim() - 1 + param.shape.ndim(), -1);
for (int i = 0; i < ishape.ndim() - 1; ++i) {
oshape[i] = ishape[i];
}
for (int i = 0; i < param.shape.ndim(); ++i) {
oshape[i + ishape.ndim() - 1] = param.shape[i];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
if (param.get_prob)
SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
for (const auto& out_shape : *out_attrs) {
if (!shape_is_known(out_shape))
return false;
}
return true;
}
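/*!
 * \brief Type inference for categorical sampling: samples take the requested
 * `dtype`; the optional log-probability output matches the input dtype.
 */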
inline bool SampleCategoricalOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
int itype = (*in_attrs)[0];
if (itype == -1)
return false;
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
if (param.get_prob) {
TYPE_ASSIGN_CHECK(*out_attrs, 1, itype);
}
return true;
}
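/*!
 * \brief Shape inference for multinomial sampling. Input 0 holds the trial
 * counts and input 1 the probabilities, whose extra trailing axis of length K
 * is the event dimension; the output shape is batch dims + `shape` + K.
 */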
inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& n_shape = (*in_attrs)[0];
const mxnet::TShape& p_shape = (*in_attrs)[1];
if (!ndim_is_known(n_shape) || !ndim_is_known(p_shape) || n_shape.ndim() + 1 != p_shape.ndim())
return false;
mxnet::TShape oshape(p_shape.ndim() + param.shape.ndim(), -1);
for (int i = 0; i < p_shape.ndim() - 1; ++i) {
if (n_shape[i] != p_shape[i])
return false;
oshape[i] = p_shape[i];
}
for (int i = 0; i < param.shape.ndim(); ++i) {
oshape[i + p_shape.ndim() - 1] = param.shape[i];
}
oshape[p_shape.ndim() + param.shape.ndim() - 1] = p_shape[p_shape.ndim() - 1];
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
}
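/*!
 * \brief Type inference for multinomial sampling: prefer an already-assigned
 * output dtype (checked against the requested `dtype`), then the count dtype,
 * then the global default.
 */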
inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
int dtype = -1;
int dtype_n = (*in_attrs)[0];
int dtype_out = (*out_attrs)[0];
if (dtype_out != -1) {
dtype = dtype_out;
if (param.dtype != -1) {
CHECK_EQ(dtype_out, param.dtype)
<< "Output type does not match requested type: " << dtype_out << " vs " << param.dtype;
}
} else {
if (dtype_n != -1) {
dtype = dtype_n;
} else {
dtype = mxnet::common::GetDefaultDtype();
}
}
TYPE_ASSIGN_CHECK(*out_attrs, 0, dtype);
return true;
}
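/*!
 * \brief Kernel for categorical sampling via inverse-CDF lookup. For each of
 * the N distributions it builds a cumulative table over the K classes, then
 * binary-searches it for each of the M uniform draws; the search returns the
 * first class whose cumulative probability reaches the draw.
 */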
struct SampleCategoricalKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i,
index_t K,
index_t M,
DType* dist,
float* uniform,
float* cum_table,
IType* out,
DType* prob) {
double acc = 0.0;
// CDF table
for (index_t c = 0; c < K; ++c) {
acc += dist[i * K + c];
cum_table[i * K + c] = static_cast<float>(acc);
}
for (index_t j = 0; j < M; ++j) {
index_t left = 0, right = K, middle = 0;
DType loc = static_cast<DType>(uniform[i * M + j]);
while (right - left > 0) {
middle = left + (right - left) / 2;
DType cum_prob = cum_table[i * K + middle];
if (cum_prob < loc) {
left = middle + 1;
} else {
right = middle;
}
}
out[i * M + j] = static_cast<IType>(left);
if (prob != nullptr)
prob[i * M + j] = logf(dist[i * K + left]);
}
}
};
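/*!
 * \brief Draw one multinomial sample by a chain of conditional binomials:
 * out[j] ~ Binomial(remaining trials, p[j] / remaining probability mass).
 * If the trials are exhausted early the remaining counts are zeroed; any
 * trials left after the first K-1 classes go to the last class.
 */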
template <typename xpu, typename NType, typename PType, typename OType>
MSHADOW_XINLINE void SampleMultinomial(NType N,
const PType* p,
OType* out,
index_t K,
typename RandGenerator<xpu, float>::Impl* gen) {
PType remaining_p = 1.0;
NType dN = N;
index_t j;
for (j = 0; j < K - 1; ++j) {
out[j] = SampleBinomial<xpu, PType, OType>(static_cast<PType>(dN), p[j] / remaining_p, gen);
dN = dN - out[j];
if (dN <= 0)
break;
remaining_p -= p[j];
}
for (j = j + 1; j < K; ++j)
out[j] = 0;
if (dN > 0)
out[K - 1] = dN;
}
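/*!
 * \brief Kernel mapping each RNG lane to multinomial draws; each run of
 * nBatch consecutive samples reuses the same (n, p) parameter row.
 */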
template <typename xpu>
struct SampleMultinomialKernel {
template <typename NType, typename PType, typename OType>
MSHADOW_XINLINE static void Map(index_t id,
RandGenerator<xpu, float> gen,
const index_t N,
const index_t step,
index_t nParm,
index_t nSample,
index_t K,
const NType* n,
const PType* p,
OType* out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
index_t nBatch(1 + (nSample - 1) / nParm);
SampleMultinomial<xpu, NType, PType, OType>(
n[i / nBatch], &p[(i / nBatch) * K], &out[i * K], K, &genImpl);
})
}
};
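/*!
 * \brief Forward pass for categorical sampling. The requested workspace holds
 * N*M uniform draws followed by N*K floats for the per-distribution CDF
 * tables consumed by SampleCategoricalKernel.
 */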
template <typename xpu>
void SampleCategoricalForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
const SampleCategoricalParam& param = nnvm::get<SampleCategoricalParam>(attrs.parsed);
index_t K = inputs[0].shape_[inputs[0].ndim() - 1];
index_t N = inputs[0].Size() / K;
index_t M = outputs[0].Size() / N;
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Random<xpu, float>* prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N * M + N * K), s);
Tensor<xpu, 1, float> uniform(workspace.dptr_, Shape1(N * M));
prnd->SampleUniform(&uniform, 0, 1);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
Kernel<SampleCategoricalKernel, xpu>::Launch(
s,
N,
K,
M,
inputs[0].dptr<DType>(),
uniform.dptr_,
workspace.dptr_ + N * M,
outputs[0].dptr<IType>(),
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
});
});
}
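/*!
 * \brief Dispatch multinomial sampling over the (count, probability, output)
 * dtypes and launch the RNG kernel.
 */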
template <typename xpu>
static inline void multinomial_op(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const OpReqType& req,
TBlob* num,
TBlob* prob,
TBlob* outputs) {
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(
num[0].type_flag_,
NType,
{MSHADOW_REAL_TYPE_SWITCH(
prob[0].type_flag_, PType, {MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
RandGenerator<xpu, OType>* pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
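// The kernel consumes a float generator; this cast relies on the generator
// state layout being independent of the value type parameter.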
RandGenerator<xpu, float>* gen = reinterpret_cast<RandGenerator<xpu, float>*>(pgen);
Tensor<xpu, 1, OType> out = outputs->FlatTo1D<xpu, OType>(s);
Tensor<xpu, 1, NType> n = num->FlatTo1D<xpu, NType>(s);
Tensor<xpu, 1, PType> p = prob->FlatTo1D<xpu, PType>(s);
index_t K = prob->shape_[prob->ndim() - 1];
LaunchRNG<SampleMultinomialKernel<xpu>, xpu>(s,
gen,
out.size(0) / K,
n.size(0),
out.size(0) / K,
K,
n.dptr_,
p.dptr_,
out.dptr_);
})})});
}
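/*! \brief Forward pass for multinomial sampling: inputs are trial counts and probabilities. */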
template <typename xpu>
void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
TBlob num = inputs[0];
TBlob prob = inputs[1];
TBlob out = outputs[0];
multinomial_op<xpu>(attrs, ctx, req[0], &num, &prob, &out);
}
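/*!
 * \brief Backward pass for categorical sampling with get_prob=true: the
 * kernel scatters gradients of the log-probability output back into the
 * distribution input (inputs are the output gradient, the sampled
 * log-probabilities, and the sampled indices).
 */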
template <typename kernel, typename xpu>
void SampleCategoricalBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
if (req[0] == kNullOp)
return;
index_t K = outputs[0].shape_[outputs[0].ndim() - 1];
index_t N = outputs[0].Size() / K;
index_t M = inputs[0].Size() / N;
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
if (req[0] != kAddTo) {
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
out = 0;
}
MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
Kernel<kernel, xpu>::Launch(s,
N,
K,
M,
inputs[0].dptr<DType>(),
inputs[1].dptr<DType>(),
inputs[2].dptr<IType>(),
outputs[0].dptr<DType>());
});
});
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_RANDOM_SAMPLE_MULTINOMIAL_OP_H_