blob: d83fce00704aa7127611ca164fa3d3070a162f7d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file np_indexing_op.h
* \brief Function declarations and definitions for numpy advanced indexing operators
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_
#include <vector>
#include "../contrib/boolean_mask-inl.h"
#include "../tensor/indexing_op.h"
#include "../tensor/broadcast_reduce_op.h"
#ifdef __CUDACC__
#include "../tensor/indexing_op-inl.cuh"
#endif
namespace mxnet {
namespace op {
namespace np_indexing_ {  // dedicated namespace so these enumerators do not clash with others
/*! \brief Slots of the operator's input vector (used as `inputs[...]` indices). */
enum Inputs {
  kArr = 0,
  kIdx = 1
};
/*! \brief Slot(s) of the operator's output vector; presumably used as `outputs[...]` indices. */
enum Outputs {
  kOut = 0
};
}  // namespace np_indexing_
/*!
 * \brief Parameters for tuple-type advanced indexing.
 *
 * Declared through the dmlc parameter machinery; the single field `axis`
 * defaults to 0.
 */
struct AdvancedIndexingMultipleParam
    : public dmlc::Parameter<AdvancedIndexingMultipleParam> {
  int axis;  // axis the tuple-type index refers to (default 0)

  DMLC_DECLARE_PARAMETER(AdvancedIndexingMultipleParam) {
    DMLC_DECLARE_FIELD(axis).set_default(0).describe("The axis of tuple type indexing");
  }
};
/*!
 * \brief Forward pass of numpy advanced indexing on NDArrays.
 *
 * Only declared here; the definition lives in another translation unit.
 * NOTE(review): the exact input/output layout is defined there, so confirm
 * it against the implementation before relying on it.
 */
template <typename xpu>
void AdvancedIndexingOpForward(
    const nnvm::NodeAttrs& attrs, const OpContext& ctx,
    const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req,
    const std::vector<NDArray>& outputs);
// template<typename xpu>
// void AdvancedIndexingMultipleOpForward(const nnvm::NodeAttrs& attrs,
// const OpContext& ctx,
// const std::vector<NDArray>& inputs,
// const std::vector<OpReqType>& req,
// const std::vector<NDArray>& outputs);
/*!
 * \brief Backward pass of numpy advanced indexing on NDArrays.
 *
 * Only declared here; the definition lives in another translation unit.
 * NOTE(review): the gradient input/output layout is defined there, so
 * confirm it against the implementation before relying on it.
 */
template <typename xpu>
void AdvancedIndexingOpBackward(
    const nnvm::NodeAttrs& attrs, const OpContext& ctx,
    const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req,
    const std::vector<NDArray>& outputs);
// template<typename xpu>
// void AdvancedIndexingMultipleOpBackward(const nnvm::NodeAttrs& attrs,
// const OpContext& ctx,
// const std::vector<NDArray>& inputs,
// const std::vector<OpReqType>& req,
// const std::vector<NDArray>& outputs);
/*!
 * \brief Backward pass of tuple-type (multi-axis) advanced indexing.
 *
 * Checks the index dtype and then delegates the actual work to
 * GatherNDBackwardImpl.
 *
 * Index dtype handling:
 * - Boolean index arrays are rejected with LOG(FATAL).
 * - Non-integer index arrays are rejected with LOG(FATAL).
 * - Only int8/int16/int32/int64 indices are accepted.
 *
 * \param attrs   operator attributes (not read here)
 * \param ctx     operator context; supplies the compute stream
 * \param inputs  two blobs:
 *                - inputs[kArr] is the incoming data (presumably the gradient
 *                  of the forward output), and its dtype selects DType.
 *                - inputs[kIdx] is the index array. Its leading dimension M is
 *                  the number of leading axes of outputs[kOut] that each index
 *                  tuple addresses.
 * \param req     write request for outputs[kOut]:
 *                - kNullOp returns immediately.
 *                - kWriteTo zero-fills the output first.
 * \param outputs one blob, written by GatherNDBackwardImpl
 */
template<typename xpu>
void AdvancedIndexingMultipleBackward(const nnvm::NodeAttrs& attrs,
                                      const OpContext& ctx,
                                      const std::vector<TBlob>& inputs,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using nnvm::dim_t;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  if (req[0] == kNullOp) return;

  const TBlob& src = inputs[np_indexing_::kArr];
  const TBlob& indices = inputs[np_indexing_::kIdx];
  const TBlob& dst = outputs[np_indexing_::kOut];
  const int idx_type = indices.type_flag_;

  if (idx_type == mshadow::kBool) {
    LOG(FATAL)
      << "Multi-dimension boolean indexing is not supported.";
  } else if (idx_type == mshadow::kInt8 ||
             idx_type == mshadow::kInt16 ||
             idx_type == mshadow::kInt32 ||
             idx_type == mshadow::kInt64) {
    // Capacity of the fixed-size `strides` shape passed to GatherNDBackwardImpl.
    constexpr int kMaxIndexedAxes = 10;
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    const mxnet::TShape& oshape = dst.shape_;
    const mxnet::TShape& ishape = indices.shape_;
    // M: number of leading axes of `dst` addressed by each index tuple.
    const dim_t M = ishape[0];
    // Guard against the division below and out-of-range writes into `strides`
    // and reads from `oshape`.
    CHECK_GT(M, 0)
      << "index array must address at least one axis, got leading dimension " << M;
    CHECK_LE(M, kMaxIndexedAxes)
      << "at most " << kMaxIndexedAxes << " axes can be indexed at once, got " << M;
    CHECK_LE(M, oshape.ndim())
      << "cannot index " << M << " axes of an array with only "
      << oshape.ndim() << " dimensions";
    // N: number of index tuples (remaining elements of the index array per axis).
    const dim_t N = ishape.Size() / M;
    // K: element count of the trailing (non-indexed) axes of `dst`.
    const dim_t K = oshape.ProdShape(M, oshape.ndim());
    // Row-major strides of the first M axes of `dst`, measured in elements.
    mshadow::Shape<kMaxIndexedAxes> strides;
    for (dim_t i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
    if (kWriteTo == req[0]) {
      Fill<true>(s, dst, req[0], 0);
    }
    MXNET_NO_INT8_TYPE_SWITCH(src.type_flag_, DType, {  // output data type switch
      MSHADOW_TYPE_SWITCH(indices.type_flag_, IType, {  // indices data type switch
        GatherNDBackwardImpl(N, M, K, strides,
                             dst.dptr<DType>(),
                             src.dptr<DType>(),
                             indices.dptr<IType>(),
                             s);
      });
    });
  } else {
    LOG(FATAL)
      << "arrays used as indices must be explicitly declared as integer (or boolean) type. "
      << "Use np.astype() to cast indices to integer or boolean.";
  }
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_INDEXING_OP_H_