/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file init_op.h
* \brief Function definition of initialization op
*/
#ifndef MXNET_OPERATOR_TENSOR_INIT_OP_H_
#define MXNET_OPERATOR_TENSOR_INIT_OP_H_
#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <vector>
#include <string>
#include <limits>
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
namespace mxnet {
namespace op {
struct InitOpParam : public dmlc::Parameter<InitOpParam> {
TShape shape;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(InitOpParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(TShape())
.describe("The shape of the output");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.describe("Target data type.");
}
};
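// Illustrative sketch (not part of this header): frontend kwargs are parsed
// into InitOpParam through dmlc::Parameter, e.g.
//   std::map<std::string, std::string> kwargs =
//       {{"shape", "(2,3)"}, {"dtype", "float32"}};
//   InitOpParam param;
//   param.Init(kwargs);  // param.shape == (2,3), param.dtype == kFloat32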
struct RangeParam : public dmlc::Parameter<RangeParam> {
double start;
dmlc::optional<double> stop;
double step;
int repeat;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
DMLC_DECLARE_FIELD(start)
.describe("Start of interval. The interval includes this value. The default start value is 0.");
DMLC_DECLARE_FIELD(stop)
.set_default(dmlc::optional<double>())
.describe("End of interval. The interval does not include this value,"
" except in some cases where step is not an integer and"
" floating point round-off affects the length of out.");
DMLC_DECLARE_FIELD(step)
.set_default(1)
.describe("Spacing between values.");
DMLC_DECLARE_FIELD(repeat)
.set_default(1)
.describe("The repeating time of all elements."
" E.g repeat=3, the element a will be repeated three times --> a, a, a.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.describe("Target data type.");
}
};
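// Illustrative: kwargs {"start": "2", "stop": "8", "step": "2"} describe the
// values 2, 4, 6; adding repeat=2 duplicates each one: 2, 2, 4, 4, 6, 6.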
/*! \brief Parse keyword arguments into a RangeParam and save it to attrs->parsed */
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
RangeParam param;
param.Init(attrs->dict);
// Mirror numpy.arange semantics: with a single argument, treat it as `stop`
// and start the range at 0.
if (!static_cast<bool>(param.stop)) {
param.stop = param.start;
param.start = 0;
}
attrs->parsed = std::move(param);
}
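// Worked example of the single-argument (numpy-style) case handled above:
// range(start=5) with stop unset is reinterpreted as start=0, stop=5,
// which produces 0, 1, 2, 3, 4.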
template<typename ParamType>
inline bool InitShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
// If the output shape is already known and no shape was requested, keep it.
if ((*out_attrs)[0].ndim() != 0 && param.shape.ndim() == 0) return true;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
return true;
}
template<typename ParamType>
inline bool InitType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
return true;
}
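// Sketch of how these inference helpers are typically registered (op name
// `_zeros` as wired up in init_op.cc):
//   NNVM_REGISTER_OP(_zeros)
//   .set_attr<nnvm::FInferShape>("FInferShape", InitShape<InitOpParam>)
//   .set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>);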
template<typename ParamType, bool rsp, bool csr>
inline bool InitStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
auto &out_stype = out_attrs->at(0);
bool dispatched = false;
// Default the output storage type to dense if it is still undefined;
// type_assign() leaves an already-requested sparse stype untouched.
type_assign(&out_stype, kDefaultStorage);
if (!dispatched && out_stype == kDefaultStorage) {
// default
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched && rsp && out_stype == kRowSparseStorage) {
// rsp
dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && csr && out_stype == kCSRStorage) {
// csr
dispatched = storage_type_assign(out_attrs, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
// no direct implementation for the requested stype: fall back to the
// dense path and log the storage fallback
dispatch_fallback(out_attrs, dispatch_mode);
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
}
return true;
}
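// Sketch: enabling both sparse branches for an op whose output may be dense,
// row_sparse or csr (the rsp/csr template flags gate the sparse paths):
//   .set_attr<FInferStorageType>("FInferStorageType",
//                                InitStorageType<InitOpParam, true, true>)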
template<typename xpu, int value>
void FillCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
ASSIGN_DISPATCH(out, req[0], scalar<DType>(value));
});
}
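// The compile-time `value` selects the fill constant, so one kernel serves
// several init ops; typical registrations (sketch, one per op):
//   .set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)  // zeros
//   .set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)  // ones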
// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
// instead of the usual compact representation.
template<typename xpu>
inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
using namespace rowsparse;
using namespace mshadow::expr;
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(dst->storage_type(), kRowSparseStorage);
MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
auto num_rows = dst->shape()[0];
dst->CheckAndAlloc({Shape1(num_rows)});
auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto val = dst->data();
Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1));
});
});
}
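// After this call the rsp array is a dense zeros array in rsp clothing: for a
// shape (3, 2) NDArray, idx = [0, 1, 2] and data = [[0,0], [0,0], [0,0]],
// rather than the compact all-zeros form with an empty idx.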
struct PopulateFullIdxRspKernel {
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType* out) {
KERNEL_ASSIGN(out[i], kWriteTo, i);
}
};
// Fill the row-index (aux) array of a RowSparse NDArray with the full set of
// row indices [0, num_rows).
template<typename xpu>
void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
using namespace rowsparse;
CHECK_EQ(dst->storage_type(), kRowSparseStorage);
nnvm::dim_t nnr = dst->shape()[0];
dst->CheckAndAllocAuxData(kIdx, mshadow::Shape1(nnr));
MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
IType* idx = dst->aux_data(kIdx).dptr<IType>();
mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
});
}
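// e.g. for a RowSparse NDArray with 4 rows, the idx aux array becomes
// [0, 1, 2, 3] after this call; the values are left untouched.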
/*!
* \brief Fill a rsp NDArray with zeros by updating the aux shape.
* \tparam xpu - cpu or gpu
* \param s - The device stream
* \param dst - NDArray which is to be set to "all zeroes"
*/
template<typename xpu>
void FillZerosRspImpl(mshadow::Stream<xpu> *, const NDArray& dst) {
if (dst.storage_initialized()) {
// reset the shapes if it's not zeros (set_aux_shape() will set storage_shape to zero as well)
dst.set_aux_shape(rowsparse::kIdx, TShape(mshadow::Shape1(0)));
}
}
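// Note that no data is written: by convention an rsp NDArray whose kIdx aux
// shape is (0,) holds no non-zero rows, so shrinking the aux shape suffices.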
/*!
* \brief Fill a CSR NDArray with zeros by updating the aux shape
* \param s - The device stream
* \param dst - NDArray which is to be set to "all zeroes"
*/
inline void FillZerosCsrImpl(mshadow::Stream<mshadow::cpu> *s, const NDArray& dst) {
dst.set_aux_shape(csr::kIdx, mshadow::Shape1(0));
dst.CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(dst.shape()[0] + 1));
TBlob indptr_data = dst.aux_data(csr::kIndPtr);
MSHADOW_IDX_TYPE_SWITCH(dst.aux_type(csr::kIndPtr), IType, {
mxnet_op::Kernel<mxnet_op::set_zero, mshadow::cpu>::Launch(
s, indptr_data.Size(), indptr_data.dptr<IType>());
});
}
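// Resulting CSR layout for a (3, n) NDArray: indptr = [0, 0, 0, 0] with idx
// and data empty, i.e. zero stored elements in every row. The gpu overload
// below is declared here and presumably defined in the matching .cu file.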
void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst);
/*!
* \brief Fill an NDArray with zeros
* \tparam xpu - cpu or gpu
* \param attrs - node attributes (unused)
* \param ctx - Device context
* \param inputs - NDArray inputs (unused)
* \param req - Request type (i.e. kWriteTo, kNullOp, etc.)
* \param outputs - Array which contains at position zero (0) the array to be set to zeros
*/
template<typename xpu>
void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(outputs.size(), 1);
auto stype = outputs[0].storage_type();
if (req[0] == kNullOp) return;
CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for FillComputeZerosEx";
if (stype == kRowSparseStorage) {
FillZerosRspImpl(s, outputs[0]);
} else if (stype == kCSRStorage) {
FillZerosCsrImpl(s, outputs[0]);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
}
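// Typical registration (sketch): sparse zeros dispatches through the
// FComputeEx path, while dense zeros uses FillCompute<xpu, 0> above:
//   .set_attr<FComputeEx>("FComputeEx<cpu>", FillComputeZerosEx<cpu>)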
struct range_fwd {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int repeat, DType start, DType step,
int req, DType* out) {
KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step);
}
};
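// Worked example: repeat=2, start=2, step=2 maps i = 0..5 to
// out = [2, 2, 4, 4, 6, 6], since out[i] = start + (i / repeat) * step.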
template<typename xpu>
void RangeCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Kernel<range_fwd, xpu>::Launch(s, outputs[0].Size(),
static_cast<int>(param.repeat), static_cast<DType>(param.start),
static_cast<DType>(param.step), req[0], outputs[0].dptr<DType>());
});
}
inline bool RangeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_NE(param.step, 0)
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Invalid range (start, stop, step) = "
<< "(" << param.start << "," << param.stop.value() << "," << param.step << ")";
} else {
CHECK(param.start > param.stop.value())
<< "Invalid range (start, stop, step)= "
<< "(" << param.start << "," << param.stop.value() << "," << param.step << ")";
}
MSHADOW_TYPE_SWITCH(param.dtype, DType, {
double out_size = std::ceil((param.stop.value() - param.start) / param.step)
* param.repeat;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({static_cast<nnvm::dim_t>(out_size)}));
});
return true;
}
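// e.g. start=0, stop=5, step=2, repeat=1: ceil((5 - 0) / 2) = 3, so the
// inferred output shape is (3,), matching the values 0, 2, 4.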
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_INIT_OP_H_