blob: fd32d9bc7bc1c6f46fa008b6866b4db5e02a827d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file np_insert_op_slice-inl.h
* \brief Function definition of insert operators (insert by slice index)
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_INSERT_OP_SLICE_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_INSERT_OP_SLICE_INL_H_
#include <vector>
#include <algorithm>
#include "./np_insert_op-inl.h"
namespace mxnet {
namespace op {
/*
* Only support slice index (the type of param 'obj' is slice).
*/
/*!
 * \brief Compute np.insert(arr, slice(start, stop, step), values) along 'axis'.
 *
 * Inputs:  inputs[0] = arr; inputs[1] = values (only when param.val is not set,
 *          i.e. values were passed as a tensor rather than a scalar parameter).
 * Output:  outputs[0] = arr with 'values' inserted before the indices described
 *          by the slice; its shape has already been inferred by the op's
 *          InferShape pass.
 * \param attrs   holds NumpyInsertParam (start/stop/step, optional axis,
 *                optional scalar val).
 * \param ctx     operator context (stream, temp-space resource).
 * \param req     write request for the single output.
 */
template<typename xpu>
void NumpyInsertSliceCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  const NumpyInsertParam& param = nnvm::get<NumpyInsertParam>(attrs.parsed);
  // When 'val' is given as a scalar parameter there is no 'values' tensor input.
  CHECK_EQ(inputs.size(), (param.val.has_value() ? 1 : 2));
  CHECK_EQ(outputs.size(), 1);
  CHECK_EQ(req.size(), 1);
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const int arr_pos = 0;
  const int val_pos = param.val.has_value() ? 0 : 1;  // unused slot when scalar val
  const int out_pos = 0;
  int ndim = inputs[arr_pos].shape_.ndim();
  int axis = param.axis.has_value() ? param.axis.value() : 0;
  TBlob arr;
  // Placeholder 0-size TBlob when 'values' comes from the scalar parameter;
  // it is materialized later (see the scalar-fill path below).
  TBlob values = param.val.has_value() ?
                 TBlob(nullptr, mxnet::TShape(0, 1), xpu::kDevMask, outputs[out_pos].type_flag_) :
                 inputs[val_pos];
  if (!param.axis.has_value()) {
    // No axis: numpy semantics flatten 'arr' first.
    arr = inputs[arr_pos].reshape(Shape1(inputs[arr_pos].shape_.Size()));
    ndim = 1;
  } else if (ndim == 0) {
    // 0-d 'arr': the output is simply the inserted value(s); no indexing needed.
    if (param.val.has_value()) {
      // Scalar value parameter: fill the 0-d output with it.
      // (Fixed: the original code had this branch and the tensor branch swapped,
      // calling param.val.value() on an empty optional in the 'else' path.)
      MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
        Fill(s, outputs[out_pos], req[0], static_cast<DType>(param.val.value()));
      });
    } else {
      // Tensor 'values' input: it must itself be 0-d to assign into a 0-d array.
      CHECK_EQ(inputs[val_pos].shape_.ndim(), 0)
        << "'arr' is a 0-d array, 'values' can not assign to it. "
        << "ValueError: assignment to 0-d array.";
      mxnet_op::copy(s, outputs[out_pos], inputs[val_pos]);
    }
    return;
  } else {
    arr = inputs[arr_pos];
    CHECK(axis >= -1 * arr.shape_.ndim() && axis < arr.shape_.ndim())
      << "Axis should be in the range of [-r, r-1] where r is the rank of input tensor";
    axis += (axis < 0) ? arr.shape_.ndim() : 0;  // normalize negative axis
  }
  index_t N = arr.shape_[axis];
  size_t indices_len = 0;                 // number of insertion indices the slice expands to
  index_t start = 0, stop = 0, step = 0;  // normalized slice arguments from 'obj'
  // get and check indices from slice or sequence of ints
  SliceIndices(param.start, param.stop, param.step,
               N, &start, &stop, &step, &indices_len);
  int numnew = 0;  // numnew = output.shape[axis] - arr.shape[axis]
  int index = 0;   // normalized single index (may arrive negative)
  mxnet::TShape val_newshape(arr.shape_.ndim(), -1);
  // Right-align values's shape to arr's ndim (prepend 1s) so it broadcasts later,
  // e.g. value shape: (2,) arr shape: (3, 2) => value shape: (1, 2).
  for (int i = values.shape_.ndim() - 1, j = arr.shape_.ndim() - 1;
       i >= 0 || j >= 0;
       --i, --j) {
    if (i >= 0 && j >= 0) {
      val_newshape[j] = values.shape_[i];
    } else if (i >= 0) {
      // values has more dims than arr: extra leading dims must be size 1
      CHECK_EQ(values.shape_[i], 1) << "index exceed limits.";
    } else {
      val_newshape[j] = 1;
    }
  }
  values.shape_.assign(val_newshape.begin(), val_newshape.end());
  // Remember the pre-broadcast shape: the single-index kernel needs the original
  // values layout to compute strides.
  mxnet::TShape old_valshape(values.shape_);
  if (indices_len == 1) {  // slice expands to exactly one insertion point
    numnew = values.shape_[axis];
    index = start;
    CHECK(index >= -1 * N && index <= N)
      << "Index should be in the range of [-r, r-1] where r is the dim size in 'axis'";
    if (index < 0) {
      index += N;
    }
  } else {
    numnew = static_cast<index_t>(indices_len);
  }
  const mxnet::TShape& outshape = outputs[out_pos].shape_;
  int dtype = outputs[out_pos].type_flag_;
  // Scalar parameter values are treated as double until cast into the output dtype.
  int vtype = param.val.has_value() ?
              mshadow::DataType<double>::kFlag :
              inputs[val_pos].type_flag_;
  if ((indices_len == 1) && param.val.has_value()) {
    MSHADOW_TYPE_SWITCH(vtype, VType, {
      // Single index and 'value' given as a numerical parameter:
      // materialize it as a 1-element tensor in temp space.
      values = TBlob(ctx.requested[0].get_space_typed<xpu, 1, VType>(Shape1(1), s));
      Fill(s, values, kWriteTo, param.val.value());
    });
  }
  if (indices_len == 1) {
    // Fast path: contiguous block insert at a single position.
    MXNET_NDIM_SWITCH(outshape.ndim(), ndim, {
      InsertScalerImpl<xpu, ndim>(s, outputs[out_pos], arr, values,
                                  mxnet_op::calc_stride(arr.shape_.get<ndim>()),
                                  mxnet_op::calc_stride(values.shape_.get<ndim>()),
                                  mxnet_op::calc_stride(old_valshape.get<ndim>()),
                                  mxnet_op::calc_stride(outshape.get<ndim>()),
                                  outshape.get<ndim>(), values.shape_.get<ndim>(),
                                  dtype, vtype, req[out_pos], axis, start, numnew,
                                  outshape.Size(), false);
    });
  } else {
    // General path: many insertion points; check 'values' broadcasts to
    // (out.shape with out.shape[axis] replaced by numnew).
    for (int i = outshape.ndim() - 1; i >= 0; --i) {
      index_t sz = outshape[i];
      if (i == axis) {
        sz = numnew;
      }
      CHECK((values.shape_[i] == 1) || (values.shape_[i] == sz));
    }
    // Workspace layout (single allocation):
    //   indices | sorted_indices | order | is_insert | origin_idx | sort scratch
    size_t temp_storage_bytes, temp_mem_size;
    temp_storage_bytes = SortByKeyWorkspaceSize<int64_t, index_t, xpu>(indices_len, false, true);
    temp_mem_size = indices_len * sizeof(int64_t) * 2 +
                    indices_len * sizeof(index_t) +
                    outshape[axis] * sizeof(index_t) * 2 +
                    temp_storage_bytes;
    Tensor<xpu, 1, char> temp_mem =
      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
    int64_t* indices_ptr = reinterpret_cast<int64_t*>(temp_mem.dptr_);
    int64_t* sorted_indices_ptr = reinterpret_cast<int64_t*>(indices_ptr + indices_len);
    index_t* order_ptr = reinterpret_cast<index_t*>(sorted_indices_ptr + indices_len);
    index_t* is_insert = reinterpret_cast<index_t*>(order_ptr + indices_len);
    index_t* origin_idx = reinterpret_cast<index_t*>(is_insert + outshape[axis]);
    Tensor<xpu, 1, char> temp_storage(reinterpret_cast<char*>(origin_idx + outshape[axis]),
                                      Shape1(temp_storage_bytes), s);
    Tensor<xpu, 1, int64_t> indices(indices_ptr, Shape1(indices_len), s);
    Tensor<xpu, 1, int64_t> sorted_indices(sorted_indices_ptr, Shape1(indices_len), s);
    Tensor<xpu, 1, index_t> order(order_ptr, Shape1(indices_len), s);
    int num_bits = 8 * sizeof(int64_t);
    // Expand the slice into explicit indices, stable-sort them (keeping the
    // original order so values are consumed in slice order), then build the
    // per-output-position insert/copy maps consumed by InsertSequenceImpl.
    Kernel<SliceToIndices, xpu>::Launch(s, indices_len, indices_ptr, start, step);
    Kernel<range_fwd, xpu>::Launch(s, indices_len, index_t{1}, index_t{0}, index_t{1},
                                   kWriteTo, order_ptr);
    mxnet::op::SortByKey(indices, order, true, &temp_storage, 0, num_bits, &sorted_indices);
    Kernel<IndicesModify, xpu>::Launch(s, indices_len, indices_ptr, order_ptr);
    mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, outshape[axis], is_insert);
    Kernel<SetIsInsert, xpu>::Launch(s, indices_len, indices_ptr, is_insert);
    Kernel<SetOriginValuesIdx, xpu>::Launch(s, indices_len, indices_ptr, origin_idx);
    Kernel<SetOriginArrIdx, xpu>::Launch(s, outshape[axis], is_insert, origin_idx);
    MXNET_NDIM_SWITCH(outshape.ndim(), ndim, {
      InsertSequenceImpl<xpu, ndim>(s, outputs[out_pos], arr, values,
                                    mxnet_op::calc_stride(arr.shape_.get<ndim>()),
                                    mxnet_op::calc_stride(values.shape_.get<ndim>()),
                                    mxnet_op::calc_stride(outshape.get<ndim>()),
                                    outshape.get<ndim>(), values.shape_.get<ndim>(),
                                    is_insert, origin_idx, dtype, vtype, req[out_pos],
                                    axis, outshape.Size());
    });
  }
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_INSERT_OP_SLICE_INL_H_