blob: 7b7b3646ff94c2b46aa9be3f70a8cb2b7d27815d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file np_matrix_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_matrix_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include <vector>
#include "../utils.h"
#include "../../../operator/nn/concat-inl.h"
#include "../../../operator/tensor/matrix_op-inl.h"
#include "../../../operator/numpy/np_matrix_op-inl.h"
namespace mxnet {
MXNET_REGISTER_API("_npi.transpose")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_transpose");
nnvm::NodeAttrs attrs;
op::NumpyTransposeParam param = {};
if (args[1].type_code() == kNull) {
param.axes = TShape(-1, 0);
} else if (args[1].type_code() == kDLInt) {
param.axes = TShape(1, args[1].operator int64_t());
} else {
param.axes = TShape(args[1].operator ObjectRef());
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyTransposeParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.expand_dims")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_expand_dims");
nnvm::NodeAttrs attrs;
op::ExpandDimParam param = {};
param.axis = args[1].operator int();
// we directly copy ExpandDimParam, which is trivially-copyable
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::ExpandDimParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.stack").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_stack");
nnvm::NodeAttrs attrs;
op::StackParam param = {};
int i = 0;
int num_inputs = 0;
std::vector<NDArray*> inputs;
while (args[i].type_code() != kDLInt) {
inputs.push_back(args[i].operator mxnet::NDArray*());
i++;
num_inputs++;
}
param.num_args = i;
param.axis = args[i].operator int64_t();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::StackParam>(&attrs);
NDArray* out = args[i + 1].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs);
if (out) {
*ret = PythonArg(i + 1);
} else {
*ret = ndoutputs[0];
}
});
MXNET_REGISTER_API("_npi.flip").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_flip");
nnvm::NodeAttrs attrs;
op::FlipParam param = {};
NDArray* out = args[2].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
if (args[1].type_code() == kNull) {
param.axis = mxnet::Tuple<int>(-1, dim_t(0));
} else if (args[1].type_code() == kDLInt) {
param.axis = Tuple<int>(1, args[1].operator int64_t());
} else {
param.axis = Tuple<int>(args[1].operator ObjectRef());
}
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::FlipParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(2);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});
MXNET_REGISTER_API("_npi.concatenate")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_concatenate");
nnvm::NodeAttrs attrs;
op::ConcatParam param = {};
int arg_size = args.num_args;
param.num_args = arg_size - 2;
if (args[arg_size - 2].type_code() == kNull) {
param.dim = dmlc::nullopt;
} else {
param.dim = args[arg_size - 2].operator int();
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::ConcatParam>(&attrs);
int num_inputs = arg_size - 2;
std::vector<NDArray*> inputs;
inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
inputs.push_back(args[i].operator mxnet::NDArray*());
}
NDArray* out = args[arg_size - 1].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs);
if (out) {
*ret = PythonArg(arg_size - 1);
} else {
*ret = ndoutputs[0];
}
});
MXNET_REGISTER_API("_npi.dstack")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_dstack");
nnvm::NodeAttrs attrs;
op::ConcatParam param = {};
int args_size = args.size();
// param.num_args
param.num_args = args_size;
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::ConcatParam>(&attrs);
// inputs
int num_inputs = args_size;
std::vector<NDArray*> inputs_vec(args_size, nullptr);
for (int i = 0; i < args_size; ++i) {
inputs_vec[i] = args[i].operator mxnet::NDArray*();
}
NDArray** inputs = inputs_vec.data();
// outputs
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.split").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_split");
int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
nnvm::NodeAttrs attrs;
op::SplitParam param = {};
param.axis = args[2].operator int();
param.squeeze_axis = false;
if (args[1].type_code() == kDLInt) {
param.indices = TShape(0, 0);
param.sections = args[1].operator int();
int index = param.axis >= 0 ? param.axis : param.axis + inputs[0]->shape().ndim();
CHECK_GE(index, 0) << "IndexError: tuple index out of range";
CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0";
CHECK_EQ(inputs[0]->shape()[index] % param.sections, 0)
<< "ValueError: array split does not result in an equal division";
} else {
TShape t = TShape(args[1].operator ObjectRef());
param.indices = TShape(t.ndim() + 1, 0);
for (int i = 0; i < t.ndim(); ++i) {
param.indices[i + 1] = t[i];
}
param.sections = 0;
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::SplitParam>(&attrs);
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});
MXNET_REGISTER_API("_npi.roll").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_roll");
nnvm::NodeAttrs attrs;
op::NumpyRollParam param = {};
if (args[1].type_code() == kNull) {
param.shift = dmlc::nullopt;
} else if (args[1].type_code() == kDLInt) {
param.shift = TShape(1, args[1].operator int64_t());
} else {
param.shift = TShape(args[1].operator ObjectRef());
}
if (args[2].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else if (args[2].type_code() == kDLInt) {
param.axis = TShape(1, args[2].operator int64_t());
} else {
param.axis = TShape(args[2].operator ObjectRef());
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyRollParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.rot90").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_rot90");
nnvm::NodeAttrs attrs;
op::NumpyRot90Param param = {};
param.k = args[1].operator int();
if (args[2].type_code() == kNull) {
param.axes = dmlc::nullopt;
} else if (args[2].type_code() == kDLInt) {
param.axes = TShape(1, args[2].operator int64_t());
} else {
param.axes = TShape(args[2].operator ObjectRef());
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyRot90Param>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.column_stack")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_column_stack");
nnvm::NodeAttrs attrs;
op::NumpyColumnStackParam param = {};
param.num_args = args.size();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyColumnStackParam>(&attrs);
int num_outputs = 0;
std::vector<NDArray*> inputs;
inputs.reserve(param.num_args);
for (int i = 0; i < param.num_args; ++i) {
inputs.push_back(args[i].operator mxnet::NDArray*());
}
auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.hstack")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_hstack");
nnvm::NodeAttrs attrs;
op::ConcatParam param = {};
param.num_args = args.size();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::ConcatParam>(&attrs);
int num_outputs = 0;
std::vector<NDArray*> inputs;
inputs.reserve(param.num_args);
for (int i = 0; i < param.num_args; ++i) {
inputs.push_back(args[i].operator mxnet::NDArray*());
}
auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.array_split")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_array_split");
nnvm::NodeAttrs attrs;
op::SplitParam param = {};
param.axis = args[2].operator int();
param.squeeze_axis = false;
if (args[1].type_code() == kDLInt) {
param.indices = TShape(0, 0);
param.sections = args[1].operator int();
CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0";
} else {
TShape t = TShape(args[1].operator ObjectRef());
param.indices = TShape(t.ndim() + 1, 0);
for (int i = 0; i < t.ndim(); ++i) {
param.indices[i + 1] = t[i];
}
param.sections = 0;
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::SplitParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});
MXNET_REGISTER_API("_npi.dsplit")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_split");
int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
CHECK_GE(inputs[0]->shape().ndim(), 3)
<< "ValueError: dsplit only works on arrays of 3 or more dimensions";
nnvm::NodeAttrs attrs;
op::SplitParam param = {};
param.axis = 2;
param.squeeze_axis = false;
if (args[1].type_code() == kDLInt) {
param.indices = TShape(0, 0);
param.sections = args[1].operator int();
CHECK_EQ(inputs[0]->shape()[2] % param.sections, 0)
<< "ValueError: array split does not result in an equal division";
CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0";
} else {
TShape t = TShape(args[1].operator ObjectRef());
param.indices = TShape(t.ndim() + 1, 0);
for (int i = 0; i < t.ndim(); ++i) {
param.indices[i + 1] = t[i];
}
param.sections = 0;
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::SplitParam>(&attrs);
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});
MXNET_REGISTER_API("_npi.hsplit")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_hsplit");
int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
CHECK_GE(inputs[0]->shape().ndim(), 1)
<< "ValueError: hsplit only works on arrays of 1 or more dimensions";
nnvm::NodeAttrs attrs;
op::SplitParam param = {};
param.axis = 0;
param.squeeze_axis = false;
if (args[1].type_code() == kDLInt) {
param.indices = TShape(0, 0);
param.sections = args[1].operator int();
CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0";
} else {
TShape t = TShape(args[1].operator ObjectRef());
param.indices = TShape(t.ndim() + 1, 0);
for (int i = 0; i < t.ndim(); ++i) {
param.indices[i + 1] = t[i];
}
param.sections = 0;
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::SplitParam>(&attrs);
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});
MXNET_REGISTER_API("_npi.vsplit")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_split");
int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
CHECK_GE(inputs[0]->shape().ndim(), 2)
<< "ValueError: vsplit only works on arrays of 2 or more dimensions";
nnvm::NodeAttrs attrs;
op::SplitParam param = {};
param.axis = 0;
param.squeeze_axis = false;
if (args[1].type_code() == kDLInt) {
param.indices = TShape(0, 0);
param.sections = args[1].operator int();
CHECK_EQ(inputs[0]->shape()[0] % param.sections, 0)
<< "ValueError: array split does not result in an equal division";
CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0";
} else {
TShape t = TShape(args[1].operator ObjectRef());
param.indices = TShape(t.ndim() + 1, 0);
for (int i = 0; i < t.ndim(); ++i) {
param.indices[i + 1] = t[i];
}
param.sections = 0;
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::SplitParam>(&attrs);
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});
MXNET_REGISTER_API("_npi.diag").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diag");
nnvm::NodeAttrs attrs;
op::NumpyDiagParam param = {};
if (features::is_enabled(features::INT64_TENSOR_SIZE))
param.k = args[1].operator int64_t();
else
param.k = args[1].operator int();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyDiagParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.rollaxis")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_rollaxis");
nnvm::NodeAttrs attrs;
op::NumpyRollaxisParam param = {};
param.axis = args[1].operator int();
param.start = args[2].operator int();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyRollaxisParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.reshape")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_reshape");
nnvm::NodeAttrs attrs;
op::NumpyXReshapeParam param = {};
if (args[1].type_code() == kNull) {
param.newshape = TShape(-1, 0);
} else if (args[1].type_code() == kDLInt) {
param.newshape = TShape(1, args[1].operator int64_t());
} else {
param.newshape = TShape(args[1].operator ObjectRef());
}
param.reverse = args[2].operator bool();
param.order = args[3].operator std::string();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyXReshapeParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.moveaxis")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_moveaxis");
nnvm::NodeAttrs attrs;
op::NumpyMoveaxisParam param = {};
if (args[1].type_code() == kNull) {
param.source = TShape(-1, 0);
} else if (args[1].type_code() == kDLInt) {
param.source = TShape(1, args[1].operator int64_t());
} else {
param.source = TShape(args[1].operator ObjectRef());
}
if (args[2].type_code() == kNull) {
param.destination = TShape(-1, 0);
} else if (args[2].type_code() == kDLInt) {
param.destination = TShape(1, args[2].operator int64_t());
} else {
param.destination = TShape(args[2].operator ObjectRef());
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyMoveaxisParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.diagonal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diagonal");
nnvm::NodeAttrs attrs;
op::NumpyDiagonalParam param = {};
if (features::is_enabled(features::INT64_TENSOR_SIZE))
param.offset = args[1].operator int64_t();
else
param.offset = args[1].operator int();
param.axis1 = args[2].operator int();
param.axis2 = args[3].operator int();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyDiagonalParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.diag_indices_from")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diag_indices_from");
nnvm::NodeAttrs attrs;
attrs.op = op;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.diagflat")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diagflat");
nnvm::NodeAttrs attrs;
op::NumpyDiagflatParam param = {};
param.k = args[1].operator int();
int num_inputs = 1;
int num_outputs = 0;
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyDiagflatParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.squeeze")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_squeeze");
nnvm::NodeAttrs attrs;
op::SqueezeParam param = {};
if (args[1].type_code() == kNull) {
param.axis = dmlc::optional<mxnet::Tuple<int>>();
} else if (args[1].type_code() == kDLInt) {
param.axis = Tuple<int>(1, args[1].operator int64_t());
} else {
param.axis = Tuple<int>(args[1].operator ObjectRef());
}
int num_inputs = 1;
int num_outputs = 0;
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::SqueezeParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
MXNET_REGISTER_API("_npi.tril_indices")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_tril_indices");
nnvm::NodeAttrs attrs;
op::NumpyTrilindicesParam param = {};
if (features::is_enabled(features::INT64_TENSOR_SIZE)) {
param.n = args[0].operator int64_t();
param.k = args[1].operator int64_t();
param.m = args[2].operator int64_t();
} else {
param.n = args[0].operator int();
param.k = args[1].operator int();
param.m = args[2].operator int();
}
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyTrilindicesParam>(&attrs);
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});
MXNET_REGISTER_API("_npi.vstack")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_vstack");
nnvm::NodeAttrs attrs;
op::NumpyVstackParam param = {};
param.num_args = args.size();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyVstackParam>(&attrs);
int num_outputs = 0;
std::vector<NDArray*> inputs_vec(args.size(), nullptr);
for (int i = 0; i < args.size(); ++i) {
inputs_vec[i] = args[i].operator mxnet::NDArray*();
}
NDArray** inputs = inputs_vec.data();
auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
} // namespace mxnet