blob: ea569162cebde075dd06883f72f806154a11f0bd [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_OPERATOR_FUSION_FUSED_OP_H_
#define MXNET_OPERATOR_FUSION_FUSED_OP_H_
#include <mxnet/operator.h>
#include <nnvm/graph.h>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#if MXNET_USE_CUDA
namespace mxnet {
namespace fusion {
/*!
 * \brief Identifiers of the kernel variants kept per fused op.
 *
 * The enumerators are consecutive and start at zero, so they can be used
 * directly as array indices; kNumKernelVariants is used in this file as the
 * size of FusedOp::kernel_functions_.
 */
enum KernelVariants {
  kGeneral = 0,
  kShapeOptimized,
  // Sentinel, not a real variant: it must remain the last enumerator so that
  // its value equals the number of variants declared above it.
  kNumKernelVariants
};
}
/*!
 * \brief dmlc parameter struct carrying the input and output counts of a
 *        fused op.
 */
struct FusedOpConfig : public dmlc::Parameter<FusedOpConfig> {
// Number of inputs; registered as a dmlc field below.
int num_inputs;
// Number of outputs; registered as a dmlc field below.
int num_outputs;
// Registers both members as dmlc parameter fields with their descriptions.
// Neither field declares a default value here.
DMLC_DECLARE_PARAMETER(FusedOpConfig) {
DMLC_DECLARE_FIELD(num_inputs).describe("Number of inputs.");
DMLC_DECLARE_FIELD(num_outputs).describe("Number of outputs.");
}
};
/*!
 * \brief Data type and dimension count recorded for one entry of a fused op.
 *
 * A default-constructed entry has both fields set to -1.
 * NOTE(review): -1 presumably means "not yet known" - confirm against the
 * code that fills these entries.
 */
struct FusedOpEntry {
  int dtype;
  int ndim;

  FusedOpEntry() : dtype{-1}, ndim{-1} {}
};
/*!
 * \brief Operator object holding a subgraph (subgraph_) together with the
 *        CUDA kernel handles compiled for it (kernel_functions_).
 *
 * Only the size accessors and the Provide / GetAux helpers are defined inline.
 * The constructor, Forward, InferShape, InferType, GetAttrs and the private
 * GenerateCode / CompileCode / CheckShapesAndTypes helpers are only declared
 * in this header.
 */
class FusedOp {
public:
// NOTE(review): presumably the number of CUDA threads per block used when
// launching the generated kernel - confirm in the implementation file.
static const int NTHREADS = 512;
/*!
 * \brief Constructs the op from node attributes and its FusedOpConfig.
 *        Defined outside this header.
 */
explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config);
~FusedOp() {}
/*! \brief Number of entries in inputs_ (implicitly converted to uint32_t). */
uint32_t num_inputs() const {
return inputs_.size();
}
/*! \brief Number of entries in outputs_ (implicitly converted to uint32_t). */
uint32_t num_outputs() const {
return outputs_.size();
}
/*!
 * \brief Forward computation entry point, templated on the device type.
 *        Declared here only; the definition lives elsewhere.
 * \param attrs node attributes
 * \param ctx operator context
 * \param inputs input blobs
 * \param req write request type for each output
 * \param outputs output blobs
 */
template <typename xpu>
void Forward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);
/*!
 * \brief Shape inference over in/out shape vectors, which are passed by
 *        pointer. Declared here only.
 * \return bool; NOTE(review): presumably true when inference is complete -
 *         confirm in the implementation.
 */
bool InferShape(const nnvm::NodeAttrs& attrs,
std::vector<mxnet::TShape>* in_attrs,
std::vector<mxnet::TShape>* out_attrs);
/*!
 * \brief Type inference over in/out type-flag vectors, which are passed by
 *        pointer. Declared here only.
 * \return bool; NOTE(review): presumably true when inference is complete -
 *         confirm in the implementation.
 */
bool InferType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs);
/*!
 * \brief Looks up attributes of type Attr by attribute name and node id.
 *        Declared here only.
 * \return tuple of (node pointer, vector of Attr, vector of Attr).
 *         NOTE(review): by analogy with GetAuxShape/GetAuxType the two vectors
 *         are presumably the node's input and output attributes - confirm.
 */
template <typename Attr>
std::tuple<const nnvm::ObjectPtr, std::vector<Attr>, std::vector<Attr>> GetAttrs(
const std::string& attr_name,
const uint32_t node_id);
/*!
 * \brief Stores copies of the given nodes and their per-node input/output
 *        shapes in the aux_* members.
 *
 * aux_nodes_ is shared with ProvideType, so each call to either function
 * replaces the node list stored by the previous one.
 */
void ProvideShape(const std::vector<nnvm::ObjectPtr>& nodes,
const std::vector<std::vector<mxnet::TShape>>& in_attrs,
const std::vector<std::vector<mxnet::TShape>>& out_attrs) {
aux_nodes_ = nodes;
aux_in_shapes_ = in_attrs;
aux_out_shapes_ = out_attrs;
}
/*!
 * \brief Stores copies of the given nodes and their per-node input/output
 *        types in the aux_* members.
 *
 * aux_nodes_ is shared with ProvideShape, so each call to either function
 * replaces the node list stored by the previous one.
 */
void ProvideType(const std::vector<nnvm::ObjectPtr>& nodes,
const std::vector<std::vector<int>>& in_attrs,
const std::vector<std::vector<int>>& out_attrs) {
aux_nodes_ = nodes;
aux_in_types_ = in_attrs;
aux_out_types_ = out_attrs;
}
/*!
 * \brief Returns the stored node pointer and copies of its input and output
 *        shapes for index node_id.
 *
 * The vectors are indexed with operator[], so node_id is not range-checked;
 * it must be a valid index into the data stored by ProvideShape.
 */
std::tuple<const nnvm::ObjectPtr, std::vector<mxnet::TShape>, std::vector<mxnet::TShape>>
GetAuxShape(const int node_id) const {
return std::make_tuple(aux_nodes_[node_id], aux_in_shapes_[node_id], aux_out_shapes_[node_id]);
}
/*!
 * \brief Returns the stored node pointer and copies of its input and output
 *        types for index node_id.
 *
 * The vectors are indexed with operator[], so node_id is not range-checked;
 * it must be a valid index into the data stored by ProvideType.
 */
std::tuple<const nnvm::ObjectPtr, std::vector<int>, std::vector<int>> GetAuxType(
const int node_id) const {
return std::make_tuple(aux_nodes_[node_id], aux_in_types_[node_id], aux_out_types_[node_id]);
}
private:
/*!
 * \brief Produces source code as a string from the request types, the
 *        input/output dtypes and dimension counts, the per-node shapes and
 *        dtypes, nvec and the kernel name. Declared here only.
 * \param check_shapes output parameter passed by pointer.
 *        NOTE(review): presumably filled with indices of shapes that need
 *        checking - confirm in the implementation.
 */
std::string GenerateCode(const std::vector<OpReqType>& req,
const std::vector<int>& in_dtypes,
const std::vector<int>& out_dtypes,
const std::vector<int>& in_ndims,
const std::vector<int>& out_ndims,
const mxnet::ShapeVector& node_shapes,
const std::vector<int>& node_dtypes,
const int nvec,
const std::string& kernel_name,
std::vector<uint32_t>* check_shapes);
/*!
 * \brief Turns source code into a CUfunction for kernel_name on device dev_id.
 *        Declared here only.
 */
CUfunction CompileCode(const std::string& code, const std::string& kernel_name, int dev_id);
/*!
 * \brief Inspects the input/output blobs and reports dtypes, dimension counts
 *        and nvec through the pointer parameters. Declared here only.
 */
void CheckShapesAndTypes(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
std::vector<int>* in_dtypes,
std::vector<int>* in_ndims,
std::vector<int>* out_dtypes,
std::vector<int>* out_ndims,
int* nvec);
// Per-input and per-output dtype/ndim records; their sizes back
// num_inputs() and num_outputs().
std::vector<FusedOpEntry> inputs_;
std::vector<FusedOpEntry> outputs_;
// Graph of the operations represented by this op.
nnvm::Graph subgraph_;
// Groups attribute values of one kind (T) for inputs, outputs and
// internal entries.
template <typename T>
struct IntermediateAttr {
std::vector<T> input_attr;
std::vector<T> output_attr;
std::vector<T> internal_attr;
};
// Shapes and types inside the subgraph are copied here because a
// subsequent call to InferShape/InferType can overwrite the original
// information stored in the subgraph_ attributes while previous
// iterations still need it.
std::vector<IntermediateAttr<mxnet::TShape>> intermediate_shapes_;
std::vector<IntermediateAttr<int>> intermediate_dtypes_;
// Data stored by ProvideShape / ProvideType and read back by
// GetAuxShape / GetAuxType. aux_nodes_ is written by both Provide functions.
std::vector<nnvm::ObjectPtr> aux_nodes_;
std::vector<std::vector<mxnet::TShape>> aux_in_shapes_;
std::vector<std::vector<mxnet::TShape>> aux_out_shapes_;
std::vector<std::vector<int>> aux_in_types_;
std::vector<std::vector<int>> aux_out_types_;
// NOTE(review): the members below are only declared here; their exact use is
// in the implementation file - confirm there before relying on these names.
std::vector<OpReqType> saved_reqs_;
std::vector<uint32_t> extra_shape_args_;
std::vector<uint32_t> check_shape_args_;
// One CUfunction slot per fusion::KernelVariants value.
CUfunction kernel_functions_[fusion::kNumKernelVariants];
// NOTE(review): neither member has an in-class initializer; presumably both
// are set by the constructor - confirm.
bool initialized_;
int kernel_function_dev_id_;
// Static mutex shared by every FusedOp instance; what it guards is not
// visible in this header.
static std::mutex mutex_;
// Per-instance mutex; what it guards is not visible in this header.
std::mutex my_mutex_;
};
/*! \brief Shared-ownership pointer to a FusedOp. */
using FusedOpPtr = std::shared_ptr<FusedOp>;
/*!
 * \brief Pairs a shared FusedOp with the id of one node.
 *
 * NOTE(review): presumably attached to helper nodes so they can query the
 * owning FusedOp for their attributes - confirm against the users of this
 * struct.
 */
struct FusedOpHelperParam {
  /*! \brief The fused op this parameter refers to (shared ownership). */
  FusedOpPtr op;
  /*! \brief Id of the node associated with this parameter. */
  uint32_t node_id;

  /*!
   * \brief Stores the op handle and node id.
   * \param op shared pointer to the fused op; taken by value and moved into
   *        the member, so the constructor itself adds no reference-count
   *        traffic beyond what the caller's argument passing already did.
   * \param node_id id of the node
   */
  FusedOpHelperParam(FusedOpPtr op, uint32_t node_id)
      : op(std::move(op)), node_id(node_id) {}
};
/*! \brief Shared-ownership pointer to a FusedOpHelperParam. */
using FusedOpHelperParamPtr = std::shared_ptr<FusedOpHelperParam>;
} // namespace mxnet
#endif // MXNET_USE_CUDA
#endif // MXNET_OPERATOR_FUSION_FUSED_OP_H_