blob: 3e1c345d59c344a96d56df2637d046ab5de5384c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file np_init_op.h
* \brief Function definition of numpy init op
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_
#include <vector>
#include <string>
#include "../tensor/init_op.h"
#include "../tensor/elemwise_unary_op.h"
namespace mxnet {
namespace op {
/*!
 * \brief Parameters of the numpy `indices` operator.
 *
 * Fields:
 *  - dimensions: the shape of the grid (required, no default).
 *  - dtype: target data type of the output; defaults to mshadow::kInt32 and
 *    accepts every type registered by MXNET_ADD_ALL_TYPES.
 *  - ctx: output context string, defaults to "" (per its description only
 *    consulted for imperative calls).
 */
struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
  mxnet::TShape dimensions;
  int dtype;
  std::string ctx;
  DMLC_DECLARE_PARAMETER(IndicesOpParam) {
    DMLC_DECLARE_FIELD(dimensions)
    .describe("The shape of the grid.");
    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kInt32)
    MXNET_ADD_ALL_TYPES
    .describe("Target data type.");
    DMLC_DECLARE_FIELD(ctx)
    .set_default("")
    // The two literals are concatenated by the compiler, so the first one
    // must end with a space to keep the sentences separated.
    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
              "Only used for imperative calls.");
  }
};
/*!
 * \brief Kernel that writes the constant coordinate `j` into a contiguous
 *        run of the indices output.
 *
 * The element written for kernel index i is
 *   out[dim_i * N + k * N / t + (N / (t * value)) * j + i] = j
 * i.e. plane `dim_i` (each plane holding N elements), block `k` of size N / t
 * inside that plane, sub-block `j` of size N / (t * value) inside that block,
 * and offset i inside the sub-block.
 */
template<int req>
struct indices_fwd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType* out,
                                  const nnvm::dim_t value,
                                  const nnvm::dim_t N,
                                  const nnvm::dim_t dim_i,
                                  const nnvm::dim_t j,
                                  const nnvm::dim_t k,
                                  const nnvm::dim_t t) {
    // Start of the output plane for this axis.
    const nnvm::dim_t plane_base = dim_i * N;
    // Start of prefix block k within the plane.
    const nnvm::dim_t block_base = k * N / t;
    // Start of the run whose coordinate along this axis equals j.
    const nnvm::dim_t run_base = N / (t * value) * j;
    KERNEL_ASSIGN(out[plane_base + block_base + run_base + i], req,
                  static_cast<DType>(j));
  }
};
/*!
 * \brief Kernel that fills a row-major matrix with ones on the main diagonal
 *        and zeros elsewhere.
 *
 * The flat index i is split into (i / n, i % n); the element is 1 when those
 * two coordinates are equal and 0 otherwise.
 */
template<int req>
struct identity {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const int n) {
    using namespace mxnet_op;
    const bool on_diagonal = (i / n) == (i % n);
    KERNEL_ASSIGN(out_data[i], req, static_cast<DType>(on_diagonal ? 1 : 0));
  }
};
template<typename xpu>
void IndicesCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
const IndicesOpParam& param = nnvm::get<IndicesOpParam>(attrs.parsed);
const TBlob& out_data = outputs[0];
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
dim_t indim = param.dimensions.ndim();
dim_t t = 1;
dim_t N = out_data.Size()/indim;
dim_t value = 0;
if (out_data.Size() == 0) return;
if (req[0] != kNullOp) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
for (int i = 0; i < indim; ++i) {
value = param.dimensions[i];
for (int k = 0; k < t; ++k) {
for (int j = 0; j < param.dimensions[i]; ++j) {
Kernel<indices_fwd<req_type>, xpu>::Launch(s, N/(param.dimensions[i] * t),
out_data.dptr<DType>(), value, N, i, j, k, t);
}
}
t = t * param.dimensions[i];
}
});
});
}
}
/*!
 * \brief Forward computation of the numpy `identity` operator.
 *
 * Fills the 2-D output with ones on the main diagonal and zeros elsewhere.
 * The row stride handed to the kernel is the column count shape_[1]; the
 * previous shape_[0] only matched for square outputs.
 */
template<typename xpu>
void IdentityCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 0U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  const TBlob& out_data = outputs[0];
  CHECK_EQ(out_data.ndim(), 2) << "identity expects a 2-D output";
  // Nothing to write; mirrors the empty-output early return in IndicesCompute.
  if (out_data.Size() == 0) return;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  // Row length of the row-major output, used to split flat indices.
  const int n_cols = static_cast<int>(out_data.shape_[1]);
  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      Kernel<identity<req_type>, xpu>::Launch(
          s, out_data.Size(), out_data.dptr<DType>(), n_cols);
    });
  });
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_