/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * Copyright (c) 2017 by Contributors
 * \file multisample_op.h
 * \brief Function definitions of operators for sampling from multiple distributions
 */
#ifndef MXNET_OPERATOR_RANDOM_MULTISAMPLE_OP_H_
#define MXNET_OPERATOR_RANDOM_MULTISAMPLE_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "./sampler.h"
namespace mxnet {
namespace op {
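
// Parameters common to all multi-sample operators: the shape to draw from each
// distribution and an optional dtype override for the output.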
struct MultiSampleParam : public dmlc::Parameter<MultiSampleParam> {
  mxnet::TShape shape;
  int dtype;
  DMLC_DECLARE_PARAMETER(MultiSampleParam) {
    DMLC_DECLARE_FIELD(shape)
        .set_default(mxnet::TShape(0, 1))
        .describe("Shape to be sampled from each random distribution.");
    DMLC_DECLARE_FIELD(dtype)
        .add_enum("None", -1)
        .add_enum("float16", mshadow::kFloat16)
        .add_enum("float32", mshadow::kFloat32)
        .add_enum("float64", mshadow::kFloat64)
        .set_default(-1)
        .describe("DType of the output in case this can't be inferred. "
                  "Defaults to float32 if not defined (dtype=None).");
  }
};
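
// Shape inference: all inputs (the distribution parameters) must share one
// common shape, and the output shape is that common shape with the requested
// per-distribution sample shape appended as trailing dimensions.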
inline bool MultiSampleOpShape(const nnvm::NodeAttrs& attrs,
                               mxnet::ShapeVector* in_attrs,
                               mxnet::ShapeVector* out_attrs) {
  CHECK_GT(in_attrs->size(), 0)
      << "sampling operator takes 1 or 2 arguments (" << in_attrs->size() << " given)";
  CHECK_LT(in_attrs->size(), 3)
      << "sampling operator takes 1 or 2 arguments (" << in_attrs->size() << " given)";
  CHECK_EQ(out_attrs->size(), 1);
  // Get the shape to be sampled for each parameter set.
  const MultiSampleParam& param = nnvm::get<MultiSampleParam>(attrs.parsed);
  mxnet::TShape sshape = param.shape;
  for (int i = 0; i < sshape.ndim(); ++i) {
    CHECK_GE(sshape[i], 0) << "shape parameter must be non-negative within each dimension";
  }
  // Check whether the output shape is already defined.
  mxnet::TShape tshape((*out_attrs)[0]);
  // The illegal case tshape.ndim() <= sshape.ndim() fails automatically
  // when the shape is propagated back from the output to the inputs below.
  if (tshape.ndim() > sshape.ndim()) {
    // Promote down by removing the last dimensions, which represent the samples.
    tshape = mxnet::TShape(tshape.begin(), tshape.begin() + (tshape.ndim() - sshape.ndim()));
  }
  // Shape assignment/checking for inputs.
  for (const auto& in_attr : *in_attrs) {
    if (!shape_assign(&tshape, in_attr)) return false;
  }
  for (size_t i = 0; i < in_attrs->size(); ++i) {
    SHAPE_ASSIGN_CHECK(*in_attrs, i, tshape);
  }
  if (tshape.ndim() > 0) {
    // Shape assignment/check for propagation from inputs to output.
    std::vector<index_t> cshape(tshape.begin(), tshape.end());
    cshape.insert(cshape.end(), sshape.begin(), sshape.end());
    mxnet::TShape oshape(cshape.begin(), cshape.end());
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  }
  return true;
}
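
// Type inference: all inputs must share one dtype; the output dtype is taken
// from an already-known output type or the `dtype` parameter (defaulting to
// float32) and must be one of float16, float32, or float64.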
inline bool MultiSampleOpType(const nnvm::NodeAttrs& attrs,
                              std::vector<int>* in_attrs,
                              std::vector<int>* out_attrs) {
  CHECK_GT(in_attrs->size(), 0)
      << "sampling operator takes 1 or 2 arguments (" << in_attrs->size() << " given)";
  CHECK_LT(in_attrs->size(), 3)
      << "sampling operator takes 1 or 2 arguments (" << in_attrs->size() << " given)";
  CHECK_EQ(out_attrs->size(), 1);
  // All inputs must have the same type.
  int dtype = -1;
  for (int in_attr : *in_attrs) {
    if (!type_assign(&dtype, in_attr)) return false;
  }
  for (size_t i = 0; i < in_attrs->size(); ++i) {
    TYPE_ASSIGN_CHECK(*in_attrs, i, dtype);
  }
  if (-1 == dtype) return false;
  // The output may have a different type, so we can't infer it from the inputs.
  const MultiSampleParam& param = nnvm::get<MultiSampleParam>(attrs.parsed);
  dtype = (*out_attrs)[0];
  if (dtype != -1) {
    if (param.dtype != -1) {
      // dtype given in args; check that it matches the output type.
      CHECK_EQ(dtype, param.dtype) << "Inferred output type does not match requested type: "
                                   << dtype << " vs " << param.dtype;
    }
  } else {
    // Output type can't be inferred. Use the type given in args or the default.
    dtype = (param.dtype == -1 ? mshadow::kFloat32 : param.dtype);
  }
  bool dtype_ok = (dtype == mshadow::kFloat16) || (dtype == mshadow::kFloat32) ||
                  (dtype == mshadow::kFloat64);
  CHECK_EQ(dtype_ok, true) << "Output type must be float16, float32, or float64: dtype is "
                           << dtype << " vs " << mshadow::kFloat16 << " or " << mshadow::kFloat32
                           << " or " << mshadow::kFloat64;
  TYPE_ASSIGN_CHECK(*out_attrs, 0, dtype);
  return true;
}
using namespace mxnet::common::random;
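
// Helper that dispatches to the sampler with either one or two input tensors,
// depending on how many parameters the underlying distribution takes.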
template<typename xpu, typename IType, typename OType, typename Sampler, int inum>
struct SamplerCaller;

template<typename xpu, typename IType, typename OType, typename Sampler>
struct SamplerCaller<xpu, IType, OType, Sampler, 1> {
  static void op(const std::vector<TBlob>& inputs,
                 const std::vector<TBlob>& outputs,
                 RandGenerator<xpu, OType> *pgen,
                 mshadow::Stream<xpu> *s) {
    Sampler sampler;
    sampler.Sample(inputs[0].FlatTo1D<xpu, IType>(s),
                   outputs[0].FlatTo1D<xpu, OType>(s),
                   pgen, s);
  }
};

template<typename xpu, typename IType, typename OType, typename Sampler>
struct SamplerCaller<xpu, IType, OType, Sampler, 2> {
  static void op(const std::vector<TBlob>& inputs,
                 const std::vector<TBlob>& outputs,
                 RandGenerator<xpu, OType> *pgen,
                 mshadow::Stream<xpu> *s) {
    Sampler sampler;
    sampler.Sample(inputs[0].FlatTo1D<xpu, IType>(s),
                   inputs[1].FlatTo1D<xpu, IType>(s),
                   outputs[0].FlatTo1D<xpu, OType>(s),
                   pgen, s);
  }
};
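
// Forward pass shared by all multi-sample operators: switches on the input and
// (floating point) output dtypes, obtains the parallel random generator from
// the op context, and delegates the actual sampling to SamplerCaller.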
template<typename xpu, typename Sampler, int inum>
void MultiSampleOpForward(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), inum);
  CHECK_EQ(outputs.size(), 1);
  if (inputs[0].Size() == 0) {
    return;
  }
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
      RandGenerator<xpu, OType> *pgen = ctx.requested[0].get_parallel_random<xpu, OType>();
      SamplerCaller<xpu, IType, OType, Sampler, inum>::op(inputs, outputs, pgen, s);
    });
  });
}
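
// Illustrative sketch only (not part of this header): a concrete sampling
// operator would typically wire these helpers into its registration in the
// corresponding .cc file roughly as follows; the operator name and sampler
// type below are placeholders, not actual MXNet identifiers.
//
//   NNVM_REGISTER_OP(_sample_my_dist)
//   .set_num_inputs(2)
//   .set_num_outputs(1)
//   .set_attr_parser(ParamParser<MultiSampleParam>)
//   .set_attr<mxnet::FInferShape>("FInferShape", MultiSampleOpShape)
//   .set_attr<nnvm::FInferType>("FInferType", MultiSampleOpType)
//   .set_attr<FResourceRequest>("FResourceRequest",
//     [](const nnvm::NodeAttrs& attrs) {
//       return std::vector<ResourceRequest>{ResourceRequest::kParallelRandom};
//     })
//   .set_attr<FCompute>("FCompute<cpu>",
//                       MultiSampleOpForward<cpu, MySampler<cpu>, 2>);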
}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_RANDOM_MULTISAMPLE_OP_H_