| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Copyright (c) 2015 by Contributors |
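 * \file ndarray_op-inl.h
 * \brief Operator whose argument/output listing, shape inference and backward
 *        dependency declaration are delegated to callbacks in an NDArrayOpInfo.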
| * \author Junyuan Xie |
| */ |
| |
| #ifndef MXNET_OPERATOR_CUSTOM_NDARRAY_OP_INL_H_ |
| #define MXNET_OPERATOR_CUSTOM_NDARRAY_OP_INL_H_ |
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/c_api.h>
#include <cstdio>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <sstream>
| #include "../operator_common.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| struct NDArrayOpParam : public dmlc::Parameter<NDArrayOpParam> { |
| void *info; |
| |
| NDArrayOpInfo *pinfo; |
| int num_inputs_, num_outputs_; |
| DMLC_DECLARE_PARAMETER(NDArrayOpParam) { |
| DMLC_DECLARE_FIELD(info); |
| } |
| }; |
| |
| template<typename xpu> |
| class NDArrayOp : public Operator { |
| public: |
| explicit NDArrayOp(NDArrayOpParam p) { |
| this->param_ = p; |
| } |
| |
| virtual void Forward(const OpContext &ctx, |
| const std::vector<TBlob> &in_data, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &out_data, |
| const std::vector<TBlob> &aux_args); |
| |
| virtual void Backward(const OpContext &ctx, |
| const std::vector<TBlob> &out_grad, |
| const std::vector<TBlob> &in_data, |
| const std::vector<TBlob> &out_data, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &in_grad, |
| const std::vector<TBlob> &aux_args); |
| |
| private: |
| NDArrayOpParam param_; |
| Context get_ctx(); |
| }; // NDArrayOp |
| |
| template<typename xpu> |
| Operator* CreateOp(NDArrayOpParam param); |
| |
| #if DMLC_USE_CXX11 |
| class NDArrayOpProp : public OperatorProperty { |
| public: |
| std::vector<std::string> ListArguments() const override { |
| char ** args = NULL; |
| CHECK(param_.pinfo->list_arguments(&args, param_.pinfo->p_list_arguments)); |
| std::vector<std::string> ret; |
| for (int i = 0; args[i] != NULL; ++i) { |
| ret.emplace_back(args[i]); |
| } |
| return ret; |
| } |
| |
| std::vector<std::string> ListOutputs() const override { |
| char ** args = NULL; |
| CHECK(param_.pinfo->list_outputs(&args, param_.pinfo->p_list_outputs)); |
| std::vector<std::string> ret; |
| for (int i = 0; args[i] != NULL; ++i) { |
| ret.emplace_back(args[i]); |
| } |
| return ret; |
| } |
| |
| int NumOutputs() const override { |
| return param_.num_outputs_; |
| } |
| |
| void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { |
| param_.Init(kwargs); |
| for (auto iter = kwargs.begin(); iter != kwargs.end(); ++iter) { |
| if (iter->first == "info") { |
| sscanf(iter->second.c_str(), "%p", ¶m_.pinfo); |
| } |
| } |
| param_.num_inputs_ = ListArguments().size(); |
| param_.num_outputs_ = ListOutputs().size(); |
| } |
| |
| std::map<std::string, std::string> GetParams() const override { |
| return param_.__DICT__(); |
| } |
| |
| |
| bool InferShape(mxnet::ShapeVector *in_shape, |
| mxnet::ShapeVector *out_shape, |
| mxnet::ShapeVector *aux_shape) const override { |
| std::vector<uint32_t*> shapes; |
| std::vector<int> ndims; |
| size_t size = 0; |
| for (const auto& s : *in_shape) size += s.ndim(); |
| std::vector<uint32_t> shapes_buffer(size); |
| uint32_t *ptr = shapes_buffer.data(); |
| for (const auto& shape : *in_shape) { |
| shapes.push_back(ptr); |
| ndims.push_back(shape.ndim()); |
| ptr = nnvm::ShapeTypeCast(shape.begin(), shape.end(), ptr); |
| } |
| shapes.resize(param_.num_inputs_+param_.num_outputs_); |
| ndims.resize(param_.num_inputs_+param_.num_outputs_); |
| CHECK(param_.pinfo->infer_shape(shapes.size(), ndims.data(), shapes.data(), |
| param_.pinfo->p_infer_shape)); |
| for (unsigned i = 0; i < in_shape->size(); ++i) { |
| SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(shapes[i], shapes[i]+ndims[i])); |
| } |
| out_shape->clear(); |
| for (unsigned i = param_.num_inputs_; i < shapes.size(); ++i) { |
| out_shape->push_back(mxnet::TShape(shapes[i], shapes[i]+ndims[i])); |
| } |
| return true; |
| } |
| |
| OperatorProperty* Copy() const override { |
| NDArrayOpProp *prop_sym = new NDArrayOpProp(); |
| prop_sym->param_ = this->param_; |
| return prop_sym; |
| } |
| |
| std::string TypeString() const override { |
| return "_NDArray"; |
| } |
| |
| std::vector<int> DeclareBackwardDependency( |
| const std::vector<int> &out_grad, |
| const std::vector<int> &in_data, |
| const std::vector<int> &out_data) const override { |
| int num_dep; |
| int *rdeps; |
| CHECK(param_.pinfo->declare_backward_dependency(out_grad.data(), in_data.data(), |
| out_data.data(), &num_dep, &rdeps, |
| param_.pinfo->p_declare_backward_dependency)); |
| std::vector<int> deps; |
| deps.insert(deps.end(), rdeps, rdeps+num_dep); |
| return deps; |
| } |
| |
| std::vector<std::pair<int, void*> > BackwardInplaceOption( |
| const std::vector<int> &out_grad, |
| const std::vector<int> &in_data, |
| const std::vector<int> &out_data, |
| const std::vector<void*> &in_grad) const override { |
| return {}; |
| } |
| |
| Operator* CreateOperator(Context ctx) const override; |
| |
| ExecType exec_type() const override { |
| return ExecType::kAsync; |
| } |
| |
| private: |
| NDArrayOpParam param_; |
| }; // class PythonProp |
| #endif // DMLC_USE_CXX11 |
| } // namespace op |
| } // namespace mxnet |
| #endif // MXNET_OPERATOR_CUSTOM_NDARRAY_OP_INL_H_ |