blob: e88b95a8bdd60c692eecd304fcc176efa06a9343 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file sample_op.h
* \brief Elementary unique sampling operators
*/
#ifndef MXNET_OPERATOR_RANDOM_UNIQUE_SAMPLE_OP_H_
#define MXNET_OPERATOR_RANDOM_UNIQUE_SAMPLE_OP_H_
#include <mxnet/operator_util.h>
#include <mshadow/base.h>
#include <string>
#include <vector>
#include <unordered_set>
#include <algorithm>
#include <cmath>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "./sampler.h"
namespace mxnet {
namespace op {
struct SampleUniqueZifpianParam : public dmlc::Parameter<SampleUniqueZifpianParam> {
int range_max;
mxnet::TShape shape;
DMLC_DECLARE_PARAMETER(SampleUniqueZifpianParam) {
DMLC_DECLARE_FIELD(range_max)
.describe("The number of possible classes.");
DMLC_DECLARE_FIELD(shape)
.set_default(mxnet::TShape())
.describe("2-D shape of the output, where shape[0] is the batch size, and shape[1] "
"is the number of candidates to sample for each batch.");
}
};
template<typename ParamType>
inline bool SampleUniqueShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 2U);
// output shape is known
if ((*out_attrs)[0].ndim() == 2 && !mxnet::ndim_is_known(param.shape)) {
SHAPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::Shape1((*out_attrs)[0][0]));
return true;
}
CHECK_EQ(param.shape.ndim(), 2U);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
SHAPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::Shape1(param.shape[0]));
return true;
}
template<typename ParamType>
inline bool SampleUniqueType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 2U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kInt64);
return true;
}
inline std::vector<ResourceRequest> UniqueSampleResource(const NodeAttrs& attrs) {
return {ResourceRequest::kParallelRandom};
}
/*!
* \brief Launch a generic kernel with parallel unique random generator
* \tparam gen random generator
* \tparam batch_size the batch size
* \tparam num_sampled the number of unique samples per batch
* \tparam Args Varargs type to eventually pass to the OP::Map() function
*/
template<typename GType, typename DType, typename OP, typename ...Args>
inline static void LaunchUniqueRNG(mshadow::Stream<cpu> *s,
common::random::RandGenerator<cpu, GType> *gen,
const int batch_size, const size_t num_sampled,
std::vector<std::unordered_set<DType>> *results,
Args... args) {
// minimal check to avoid division by zero, below.
// if `N` is zero the map operation is a no-op in any case.
if (batch_size <= 0 || num_sampled <= 0) return;
const int nthread = std::min(batch_size, RandGenerator<cpu>::kNumRandomStates);
const int step = (batch_size + nthread - 1) / nthread;
Kernel<OP, cpu>::Launch(s, nthread, *gen, batch_size, num_sampled, results, step, args...);
}
struct UniqueSampleUniformKernel {
template<typename GType, typename DType>
MSHADOW_XINLINE static void Map(int tid, RandGenerator<cpu, GType> gen,
const int batch_size, const size_t num_sampled,
std::vector<std::unordered_set<DType>> *results,
const int step, const GType log_range_max,
DType *samples, DType *num_tries) {
const int begin = tid * step;
const int end = (tid + 1) * step;
typename RandGenerator<cpu, GType>::Impl generator(&gen, tid);
for (int i = begin; i < end && i < batch_size; i++) {
auto &result = results->at(i);
const int base = i * num_sampled;
DType tries = 0;
while (result.size() != num_sampled) {
const double x = generator.uniform();
const DType value = static_cast<DType>(lround(exp(x * log_range_max)) - 1);
// sampling without replacement
if (result.find(value) == result.end()) {
samples[base + result.size()] = value;
result.emplace(value);
}
tries += 1;
}
num_tries[i] = tries;
}
}
};
inline void SampleUniqueZifpian(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using DType = int64_t;
using GType = double;
const SampleUniqueZifpianParam& param = nnvm::get<SampleUniqueZifpianParam>(attrs.parsed);
const int batch_size = param.shape[0];
const size_t num_sampled = static_cast<size_t>(param.shape[1]);
const double log_range_max = log(param.range_max);
CHECK_EQ(outputs.size(), 2U);
CHECK_LE(num_sampled, param.range_max)
<< "Number of samples cannot exceed the number of possible classes";
// rand generator resource and result sets
RandGenerator<cpu, GType> *pgen = ctx.requested[0].get_parallel_random<cpu, GType>();
std::vector<std::unordered_set<DType>> results(batch_size);
for (int i = 0; i < batch_size; i++) {
results[i].reserve(num_sampled);
}
DType *num_tries = outputs[1].dptr<DType>();
DType *samples = outputs[0].dptr<DType>();
Stream<cpu> *s = ctx.get_stream<cpu>();
LaunchUniqueRNG<GType, DType, UniqueSampleUniformKernel>(s, pgen, batch_size, num_sampled,
&results, log_range_max, samples,
num_tries);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_RANDOM_UNIQUE_SAMPLE_OP_H_