blob: f3c405d7103c8451b6a2da3cbaf5c0de9ccf49fd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file init_op.h
* \brief Function definition of initialization op
*/
#ifndef MXNET_OPERATOR_TENSOR_INIT_OP_H_
#define MXNET_OPERATOR_TENSOR_INIT_OP_H_
#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/imperative.h>
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <vector>
#include <string>
#include <algorithm>
#include <limits>
#include <type_traits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
#include "../mshadow_op.h"
namespace mxnet {
namespace op {
/*!
 * \brief Parameters for initialization ops that take an output shape,
 *        an (imperative-only) context string and a target dtype.
 */
struct InitOpParam : public dmlc::Parameter<InitOpParam> {
  /*! \brief output shape; defaults to mxnet::TShape(0, 1) */
  mxnet::TShape shape;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief output dtype flag, defaults to mshadow::kFloat32 */
  int dtype;
  DMLC_DECLARE_PARAMETER(InitOpParam) {
    DMLC_DECLARE_FIELD(shape)
    .set_default(mxnet::TShape(0, 1))
    .describe("The shape of the output");
    // The two literals are concatenated; keep a space after the first sentence.
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
    MXNET_ADD_ALL_TYPES
    .describe("Target data type.");
  }
};
/*!
 * \brief Init-op parameters whose dtype is a free integer (no enum restriction)
 *        defaulting to -1.
 *
 * NOTE(review): -1 presumably means "dtype not specified / to be inferred";
 * confirm against the ops that use this struct.
 */
struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
  /*! \brief output shape; defaults to mxnet::TShape() */
  mxnet::TShape shape;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief dtype flag; -1 by default */
  int dtype;
  DMLC_DECLARE_PARAMETER(InitOpWithoutDTypeParam) {
    DMLC_DECLARE_FIELD(shape)
    .set_default(mxnet::TShape())
    .describe("The shape of the output");
    // The two literals are concatenated; keep a space after the first sentence.
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype)
    .set_default(-1)
    .describe("Target data type.");
  }
};
/*!
 * \brief Parameters of the eye operator: an N x M matrix with ones on the k-th
 *        diagonal (M == 0 means "use N").
 */
struct EyeParam : public dmlc::Parameter<EyeParam> {
  /*! \brief number of rows (required) */
  nnvm::dim_t N;
  /*! \brief number of columns; 0 selects N */
  nnvm::dim_t M;
  /*! \brief diagonal offset: 0 main, > 0 upper, < 0 lower */
  nnvm::dim_t k;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief output dtype flag, defaults to mshadow::kFloat32 */
  int dtype;
  DMLC_DECLARE_PARAMETER(EyeParam) {
    DMLC_DECLARE_FIELD(N)
    .describe("Number of rows in the output.");
    DMLC_DECLARE_FIELD(M)
    .set_default(0)
    .describe("Number of columns in the output. If 0, defaults to N");
    // Adjacent literals are concatenated; each sentence ends with a space so
    // the generated description reads correctly.
    DMLC_DECLARE_FIELD(k)
    .set_default(0)
    .describe("Index of the diagonal. 0 (the default) refers to the main diagonal. "
              "A positive value refers to an upper diagonal. "
              "A negative value to a lower diagonal.");
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
    .add_enum("float32", mshadow::kFloat32)
    .add_enum("float64", mshadow::kFloat64)
    .add_enum("float16", mshadow::kFloat16)
    .add_enum("uint8", mshadow::kUint8)
    .add_enum("int8", mshadow::kInt8)
    .add_enum("int32", mshadow::kInt32)
    .add_enum("int64", mshadow::kInt64)
    .describe("Target data type.");
  }
};
/*!
 * \brief Shape inference for eye-like ops with no inputs.
 *
 * The single output is 2-D: (N, M), where M falls back to N when it is not
 * positive.
 * \return always true
 */
template<typename ParamType>
inline bool InitEyeShape(const nnvm::NodeAttrs& attrs,
                         mxnet::ShapeVector *in_attrs,
                         mxnet::ShapeVector *out_attrs) {
  const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  const auto rows = p.N;
  const auto cols = (p.M > 0) ? p.M : p.N;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(rows, cols));
  return true;
}
/*!
 * \brief Kernel that writes 1 at the i-th element of a shifted diagonal of a
 *        dense matrix stored row by row with num_cols columns.
 *
 * Element i is written at row (i + init_col - k), column (i + init_col).
 * \tparam req OpReqType applied through KERNEL_ASSIGN
 */
template<int req>
struct eye_dns_fill {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
                                  const nnvm::dim_t init_col,
                                  const nnvm::dim_t k,
                                  const nnvm::dim_t num_cols) {
    // i is taken as index_t (not int) so the launch index is never truncated.
    KERNEL_ASSIGN(out_data[(i + init_col - k) * num_cols + i + init_col], req,
                  static_cast<DType>(1));
  }
};
/*!
 * \brief Parameters of the range operator.
 *
 * `start` is declared without a default; RangeParamParser reinterprets a lone
 * `start` as `stop` (setting start to 0) when infer_range is false.
 */
struct RangeParam : public dmlc::Parameter<RangeParam> {
  /*! \brief first value (inclusive) */
  double start;
  /*! \brief end of the interval (exclusive); optional */
  dmlc::optional<double> stop;
  /*! \brief spacing between values */
  double step;
  /*! \brief how many times each value is repeated */
  int repeat;
  /*! \brief derive stop from start/step/repeat and the output size */
  bool infer_range;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief output dtype flag, defaults to mshadow::kFloat32 */
  int dtype;
  DMLC_DECLARE_PARAMETER(RangeParam) {
    DMLC_DECLARE_FIELD(start)
    .describe("Start of interval. The interval includes this value. The default start value is 0.");
    DMLC_DECLARE_FIELD(stop)
    .set_default(dmlc::optional<double>())
    .describe("End of interval. The interval does not include this value,"
              " except in some cases where step is not an integer and"
              " floating point round-off affects the length of out.");
    DMLC_DECLARE_FIELD(step)
    .set_default(1)
    .describe("Spacing between values.");
    DMLC_DECLARE_FIELD(repeat)
    .set_default(1)
    .describe("The repeating time of all elements."
              " E.g repeat=3, the element a will be repeated three times --> a, a, a.");
    DMLC_DECLARE_FIELD(infer_range)
    .set_default(false)
    .describe("When set to True, infer the stop position from the start, step, "
              "repeat, and output tensor size.");
    // The two literals are concatenated; keep a space after the first sentence.
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
    MXNET_ADD_ALL_TYPES
    .describe("Target data type.");
  }
};
/*!
 * \brief Parameters of the range-like operator, whose output length is taken
 *        from an input array (optionally from a single axis of it).
 */
struct RangeLikeParam : public dmlc::Parameter<RangeLikeParam> {
  /*! \brief first value (inclusive), 0 by default */
  double start;
  /*! \brief spacing between values, 1 by default */
  double step;
  /*! \brief how many times each value is repeated */
  int repeat;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief optional input axis whose size determines the output length */
  dmlc::optional<int> axis;
  DMLC_DECLARE_PARAMETER(RangeLikeParam) {
    DMLC_DECLARE_FIELD(start)
    .set_default(0)
    .describe("Start of interval. The interval includes this value. The default start value is 0.");
    DMLC_DECLARE_FIELD(step)
    .set_default(1)
    .describe("Spacing between values.");
    DMLC_DECLARE_FIELD(repeat)
    .set_default(1)
    .describe("The repeating time of all elements."
              " E.g repeat=3, the element a will be repeated three times --> a, a, a.");
    // The two literals are concatenated; keep a space after the first sentence.
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(axis)
    .set_default(dmlc::optional<int>())
    .describe("Arange elements according to the size of a certain axis of input array."
              " The negative numbers are interpreted counting from the backward."
              " If not provided, will arange elements according to the input shape.");
  }
};
/*! \brief Initialize and fill output with an arbitrary value */
struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
  /*! \brief output shape; defaults to mxnet::TShape() */
  mxnet::TShape shape;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief output dtype flag, defaults to mshadow::kFloat32 */
  int dtype;
  /*! \brief fill value (required) */
  double value;
  DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) {
    DMLC_DECLARE_FIELD(shape)
    .set_default(mxnet::TShape())
    .describe("The shape of the output");
    // The two literals are concatenated; keep a space after the first sentence.
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
    MXNET_ADD_ALL_TYPES
    .describe("Target data type.");
    DMLC_DECLARE_FIELD(value)
    .describe("Value with which to fill newly created tensor");
  }
};
/*!
 * \brief Parse attrs->dict into a RangeParam and store it in attrs->parsed.
 *
 * When only one bound is supplied (no stop, infer_range false), the supplied
 * start becomes the exclusive stop and start is reset to 0.
 */
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
  RangeParam parsed;
  parsed.Init(attrs->dict);
  const bool has_stop = static_cast<bool>(parsed.stop);
  const bool infer = static_cast<bool>(parsed.infer_range);
  if (!infer && !has_stop) {
    // A lone bound is the upper limit of [0, bound).
    parsed.stop = parsed.start;
    parsed.start = 0;
  }
  attrs->parsed = std::move(parsed);
}
/*! \brief Parameters of the linspace operator: num evenly spaced samples. */
struct LinspaceParam : public dmlc::Parameter<LinspaceParam> {
  /*! \brief first sample */
  double start;
  /*! \brief last sample if endpoint is true, otherwise the excluded end */
  double stop;
  /*! \brief number of samples (must be non-negative) */
  int num;
  /*! \brief whether stop is included as the last sample */
  bool endpoint;
  /*! \brief output context string, only consulted by imperative calls */
  std::string ctx;
  /*! \brief output dtype flag, defaults to mshadow::kFloat32 */
  int dtype;
  DMLC_DECLARE_PARAMETER(LinspaceParam) {
    DMLC_DECLARE_FIELD(start)
    .describe("The starting value of the sequence.");
    DMLC_DECLARE_FIELD(stop)
    .describe("The ending value of the sequence");
    DMLC_DECLARE_FIELD(num)
    .describe("Number of samples to generate. Must be non-negative.");
    DMLC_DECLARE_FIELD(endpoint)
    .set_default(true)
    .describe("If True, stop is the last sample. Otherwise, it is not included.");
    // The two literals are concatenated; keep a space after the first sentence.
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
    MXNET_ADD_ALL_TYPES
    .describe("Target data type.");
  }
};
/*!
 * \brief Shape inference for init ops with no inputs: output takes param.shape.
 *
 * Under legacy (non-numpy) shape semantics the parameter shape is converted to
 * the numpy convention first. An already-known output shape is kept when the
 * parameter shape is not fully known.
 * \return whether the output shape is fully known
 */
template<typename ParamType>
inline bool InitShape(const nnvm::NodeAttrs& attrs,
                      mxnet::ShapeVector *in_attrs,
                      mxnet::ShapeVector *out_attrs) {
  const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  mxnet::TShape requested(p.shape);
  if (!Imperative::Get()->is_np_shape()) common::ConvertToNumpyShape(&requested);
  const bool out_known = shape_is_known((*out_attrs)[0]);
  if (out_known && !shape_is_known(requested)) {
    // Nothing more specific to contribute; keep what was already inferred.
    return true;
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, requested);
  return shape_is_known(out_attrs->at(0));
}
/*!
 * \brief Type inference for init ops: the single output gets param.dtype.
 * \tparam num_in expected number of inputs (0 by default)
 * \return always true
 */
template<typename ParamType, int num_in = 0U>
inline bool InitType(const nnvm::NodeAttrs& attrs,
                     std::vector<int> *in_attrs,
                     std::vector<int> *out_attrs) {
  const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
  const std::vector<int>& inputs = *in_attrs;
  CHECK_EQ(inputs.size(), num_in);
  CHECK_EQ(out_attrs->size(), 1U);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, p.dtype);
  return true;
}
/*!
 * \brief Storage-type inference for init ops with no inputs.
 *
 * An undefined output storage type is first set to dense. Dense outputs
 * dispatch to FCompute. Row-sparse and CSR outputs dispatch to FComputeEx
 * only when the matching template flag (rsp / csr) is set. Everything else
 * goes through dispatch_fallback.
 * \return whether a dispatch mode was assigned
 */
template<typename ParamType, bool rsp, bool csr>
inline bool InitStorageType(const nnvm::NodeAttrs& attrs,
                            const int dev_mask,
                            DispatchMode* dispatch_mode,
                            std::vector<int> *in_attrs,
                            std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  int& stype = out_attrs->at(0);
  type_assign(&stype, kDefaultStorage);
  if (stype == kDefaultStorage &&
      storage_type_assign(out_attrs, kDefaultStorage,
                          dispatch_mode, DispatchMode::kFCompute)) {
    return true;
  }
  if (rsp && stype == kRowSparseStorage &&
      storage_type_assign(out_attrs, kRowSparseStorage,
                          dispatch_mode, DispatchMode::kFComputeEx)) {
    return true;
  }
  if (csr && stype == kCSRStorage &&
      storage_type_assign(out_attrs, kCSRStorage,
                          dispatch_mode, DispatchMode::kFComputeEx)) {
    return true;
  }
  // No specialized path applied.
  return dispatch_fallback(out_attrs, dispatch_mode);
}
/*!
 * \brief General-purpose blob value-filling function
 * \tparam is_integer Whether to optimize for an integer value (enables the
 *                    dedicated fill-with-ones kernel when val == 1)
 * \tparam ValueType Data type of supplied value
 * \tparam xpu cpu or gpu
 * \param s Stream
 * \param b The blob to fill with a value
 * \param req Request type (kNullOp, kWriteTo, etc)
 * \param val The value to use for the filling operation
 *
 * Zero-size blobs, kNullOp, and adding zero (kAddTo with val == 0) are no-ops.
 */
template <bool is_integer = false, typename ValueType, typename xpu>
void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueType val) {
  const size_t size = b.Size();
  // Nothing to write for an empty tensor or when no output was requested.
  if (size == 0 || req == kNullOp) return;
  if (val == 0) {
    // x + 0 == x: accumulating zeros leaves the blob unchanged.
    if (req == kAddTo) return;
    // Heuristic size cut-off: small CPU blobs are cleared with a plain memset,
    // while larger (or non-CPU) blobs go through a kernel launch.
    constexpr size_t kCpuMemsetMaxSize = 50000;
    if (b.dev_mask() == cpu::kDevMask && size < kCpuMemsetMaxSize) {
      MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
        memset(b.dptr_, 0, size * sizeof(DType));
      });
    } else {
      // Optimize common use-case of filling with zeros
      MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
          mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_to_int<0>, Req>, xpu>::Launch(
            s, size, b.dptr<DType>());
        });
      });
    }
  } else if (is_integer && val == 1) {
    // Optimize common use-case of filling with ones
    MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::set_one, Req>, xpu>::Launch(
          s, size, b.dptr<DType>());
      });
    });
  } else {
    // Generic fill kernel from variable
    MSHADOW_TYPE_SWITCH(b.type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
          s, size, b.dptr<DType>(), static_cast<DType>(val));
      });
    });
  }
}
/*!
 * \brief FCompute that fills outputs[0] with the compile-time integer `value`,
 *        honoring req[0]; inputs are ignored.
 */
template<typename xpu, int value>
void FillCompute(const nnvm::NodeAttrs& attrs,
                 const OpContext& ctx,
                 const std::vector<TBlob>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<TBlob>& outputs) {
  mshadow::Stream<xpu>* const stream = ctx.get_stream<xpu>();
  const TBlob& dst = outputs[0];
  Fill<true>(stream, dst, req[0], value);
}
/*!
 * \brief FCompute that fills the single output with InitOpWithScalarParam::value.
 *
 * Expects no inputs and exactly one output; req[0] controls how it is written.
 */
template<typename xpu>
void InitFillWithScalarCompute(const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<TBlob> &outputs) {
  CHECK_EQ(inputs.size(), 0U);
  CHECK_EQ(outputs.size(), 1U);
  const InitOpWithScalarParam& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
  const double fill_value = param.value;
  mshadow::Stream<xpu>* const stream = ctx.get_stream<xpu>();
  Fill<false>(stream, outputs[0], req[0], fill_value);
}
/*!
 * \brief Kernel writing the identity index sequence: out[i] = i.
 *
 * The index is taken as index_t (rather than int) so row counts beyond the
 * int range are not truncated.
 */
struct PopulateFullIdxRspKernel : public mxnet_op::tunable {
  template<typename IType>
  MSHADOW_XINLINE static void Map(index_t i, IType* out) {
    KERNEL_ASSIGN(out[i], kWriteTo, i);
  }
};
/*!
 * \brief Write an all-zeros row-sparse NDArray with every row explicitly stored,
 *        instead of the usual compact representation.
 *
 * Storage is allocated for all num_rows rows, the value buffer is zero-filled,
 * and the row-index array becomes 0, 1, ..., num_rows - 1.
 * \param s   device stream
 * \param dst row-sparse NDArray to overwrite
 */
template<typename xpu>
inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
  CHECK_EQ(dst->storage_type(), kRowSparseStorage);
  const index_t row_count = dst->shape()[0];
  MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(rowsparse::kIdx), IType, {
    dst->CheckAndAlloc({mshadow::Shape1(row_count)});
    Fill<true>(s, dst->data(), kWriteTo, 0);
    // Every row is present, so the stored row indices are simply 0..row_count-1.
    IType* row_idx = dst->aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s).dptr_;
    mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, row_count, row_idx);
  });
}
/*!
 * \brief Fill a rsp NDArray with zeros by updating the aux shape.
 *
 * If storage is initialized, the row-index aux shape is reset to length 0
 * (set_aux_shape() also sets the storage shape to zero). Uninitialized storage
 * is left as is, presumably because it already represents all zeros.
 * \tparam xpu - cpu or gpu (the stream argument is unused)
 * \param dst - NDArray which is to be set to "all zeroes"
 */
template<typename xpu>
void FillZerosRspImpl(mshadow::Stream<xpu> *, const NDArray& dst) {
  CHECK_EQ(dst.storage_type(), kRowSparseStorage) << "dst should be an RSP NDArray";
  if (!dst.storage_initialized()) return;
  const mxnet::TShape no_rows(mshadow::Shape1(0));
  dst.set_aux_shape(rowsparse::kIdx, no_rows);
}
/*!
 * \brief Fill a CSR NDArray with zeros by updating the aux shape (CPU version).
 *
 * The csr::kIdx aux array is resized to length 0. The csr::kIndPtr aux array
 * is (re)allocated with shape()[0] + 1 entries and zero-filled.
 * \param s - The device stream
 * \param dst - NDArray which is to be set to "all zeroes"
 */
inline void FillZerosCsrImpl(mshadow::Stream<mshadow::cpu> *s, const NDArray& dst) {
  CHECK_EQ(dst.storage_type(), kCSRStorage) << "dst is not a CSR NDArray";
  dst.set_aux_shape(csr::kIdx, mshadow::Shape1(0));
  dst.CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(dst.shape()[0] + 1));
  Fill<true>(s, dst.aux_data(csr::kIndPtr), kWriteTo, 0);
}
/*!
 * \brief GPU overload of FillZerosCsrImpl; its definition is not in this header
 *        (presumably in a CUDA translation unit).
 */
void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst);
/*!
 * \brief FComputeEx that sets a row-sparse or CSR output to all zeros.
 * \tparam xpu - cpu or gpu
 * \param attrs - node attributes (only forwarded on the unimplemented path)
 * \param ctx - Device context
 * \param inputs - NDArray inputs (only forwarded on the unimplemented path)
 * \param req - Request type; kNullOp and kAddTo leave the output untouched
 * \param outputs - Array which contains at position zero (0) the array to be set to zeros
 */
template<typename xpu>
void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<NDArray>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<NDArray>& outputs) {
  mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
  CHECK_EQ(outputs.size(), 1);
  const NDArray& dst = outputs[0];
  const auto dst_stype = dst.storage_type();
  // Writing nothing, or adding zero (x + 0 == x), requires no work.
  if (req[0] == kNullOp || req[0] == kAddTo) return;
  switch (dst_stype) {
    case kRowSparseStorage:
      FillZerosRspImpl(stream, dst);
      break;
    case kCSRStorage:
      FillZerosCsrImpl(stream, dst);
      break;
    default:
      LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
      break;
  }
}
/*!
 * \brief Dense forward of eye: zero the output according to req[0], then
 *        write ones on the k-th diagonal of the N x M matrix (M = N if M <= 0).
 */
template<typename xpu>
void EyeFill(const nnvm::NodeAttrs& attrs,
             const OpContext& ctx,
             const std::vector<TBlob>& inputs,
             const std::vector<OpReqType>& req,
             const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 0U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
  const TBlob& out = outputs[0];
  const nnvm::dim_t zero = 0;
  const nnvm::dim_t rows = param.N;
  const nnvm::dim_t cols = param.M > 0 ? param.M : param.N;
  const nnvm::dim_t offset = std::abs(param.k);
  // Number of diagonal entries that fall inside the matrix.
  const nnvm::dim_t diag_len = (param.k > 0)
      ? std::min(std::max(cols - offset, zero), rows)
      : std::min(std::max(rows - offset, zero), cols);
  // Column of the first diagonal entry: k for upper diagonals, 0 otherwise.
  const nnvm::dim_t first_col = std::max(zero, param.k);
  MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Fill(s, out, req[0], static_cast<DType>(0));
      if (diag_len > 0) {
        mxnet_op::Kernel<eye_dns_fill<req_type>, xpu>::Launch(
            s, diag_len, out.dptr<DType>(), first_col, param.k, cols);
      }
    });
  });
}
/*!
 * \brief Kernel: out[i] = start + (i / repeat) * step, so each value of the
 *        sequence appears `repeat` consecutive times.
 */
struct range_fwd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, index_t repeat, DType start, DType step,
                                  int req, DType* out) {
    const index_t group = i / repeat;
    KERNEL_ASSIGN(out[i], req, start + group * step);
  }
};
template<typename xpu, typename ParamType>
void RangeCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
// Force unsigned params to take two's complement form on ARM to ensure consistency with x86
// results. Casting negative floats to unsigned types is undefined in the CPP standard.
auto step = std::is_signed<DType>() ? param.step : static_cast<index_t>(param.step);
auto start = std::is_signed<DType>() ? param.start : static_cast<index_t>(param.start);
Kernel<range_fwd, xpu>::Launch(s,
outputs[0].Size(),
static_cast<int>(param.repeat),
static_cast<DType>(start),
static_cast<DType>(step),
req[0],
outputs[0].dptr<DType>());
});
}
/*!
 * \brief Shape inference for range: a 1-D output of
 *        ceil((stop - start) / step) * repeat elements.
 *
 * step must be non-zero, repeat positive, and start must lie strictly on the
 * correct side of stop for the sign of step.
 * \return false when infer_range is set and stop is still unknown, else true
 */
inline bool RangeShape(const nnvm::NodeAttrs& attrs,
                       mxnet::ShapeVector *in_attrs,
                       mxnet::ShapeVector *out_attrs) {
  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_NE(param.step, 0)
    << "Range does not support step=0, received " << param.step;
  CHECK(param.repeat > 0)
    << "Range only supports repeat > 0, received " << param.repeat;
  // Without stop the output length cannot be computed from the parameters yet.
  if (param.infer_range && !param.stop.has_value()) return false;
  const double start = param.start;
  const double stop = param.stop.value();
  const double step = param.step;
  if (step > 0) {
    CHECK(start < stop)
      << "Invalid range (start, stop, step) = "
      << "(" << start << "," << stop << "," << step << ")";
  } else {
    CHECK(start > stop)
      << "Invalid range (start, stop, step)= "
      << "(" << start << "," << stop << "," << step << ")";
  }
  const double length = std::ceil((stop - start) / step);
  const double out_size = length * param.repeat;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
  return true;
}
/*!
 * \brief Kernel: out[i] = start + step * i, computed in double and then
 *        converted to the output dtype. `stop` is accepted but not used.
 */
struct linspace_fwd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double step,
                                  int req, DType* out) {
    const double sample = start + step * i;
    KERNEL_ASSIGN(out[i], req, static_cast<DType>(sample));
  }
};
/*!
 * \brief Forward of linspace: outputs[0][i] = start + i * step.
 *
 * With endpoint the span [start, stop] is split into num - 1 intervals,
 * otherwise into num intervals. With no intervals, step is 0.
 */
template<typename xpu>
void LinspaceCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
  // step does not depend on the output dtype, so compute it once up front.
  const int intervals = param.endpoint ? param.num - 1 : param.num;
  const double step = (intervals > 0) ? (param.stop - param.start) / intervals : 0.0;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Kernel<linspace_fwd, xpu>::Launch(s,
                                      outputs[0].Size(),
                                      param.start,
                                      param.stop,
                                      step,
                                      req[0],
                                      outputs[0].dptr<DType>());
  });
}
/*!
 * \brief Shape inference for linspace: a 1-D output of param.num elements.
 *        num must be non-negative.
 * \return always true
 */
inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
  const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 0U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int count = param.num;
  CHECK_GE(count, 0)
    << "Number of sequence should be non-negative, received " << count;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(count)}));
  return true;
}
/*!
 * \brief Shape inference for range_like.
 *
 * Without `axis` the output has the input's shape. With `axis` (negative
 * values count from the end) the output is 1-D with length equal to the
 * input's size along that axis.
 * \return false if `axis` is given but the input rank is not yet known
 *         (inference is deferred), true otherwise
 */
inline bool RangeLikeShape(const nnvm::NodeAttrs& attrs,
                           mxnet::ShapeVector *in_attrs,
                           mxnet::ShapeVector *out_attrs) {
  const RangeLikeParam& param = nnvm::get<RangeLikeParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  const mxnet::TShape& in_shape = (*in_attrs)[0];
  if (!param.axis.has_value()) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_shape);
    return true;
  }
  // Resolving the axis needs the input rank. If it is still unknown, defer
  // instead of failing the range check below on a placeholder rank.
  if (!ndim_is_known(in_shape)) return false;
  const int ndim = in_shape.ndim();
  const int axis = param.axis.value();
  const int real_axis = axis < 0 ? axis + ndim : axis;
  CHECK(real_axis >= 0 && real_axis < ndim)
    << "cannot handle param.axis " << param.axis.value() << ".";
  const index_t out_size = in_shape[real_axis];
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
  return true;
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_INIT_OP_H_