blob: 96ac4426270a5dd45f4106bf14db931e4d0c557e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
#define MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_
#if MXNET_USE_TENSORRT
#include <map>
#include <string>
#include <vector>
#include "./graph_executor.h"
namespace mxnet {
namespace exec {
class TrtGraphExecutor : public GraphExecutor {
public:
static Executor* TensorRTBind(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
std::vector<Context> *in_arg_ctxes,
std::vector<Context>* arg_grad_ctxes,
std::vector<Context>* aux_state_ctxes,
std::unordered_map<std::string, TShape>* arg_shape_map,
std::unordered_map<std::string, int>* arg_dtype_map,
std::unordered_map<std::string, int>* arg_stype_map,
std::vector<OpReqType>* grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
std::unordered_map<std::string, NDArray>*
shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);
virtual void Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
std::vector<Context> *in_arg_ctxes,
std::vector<Context> *arg_grad_ctxes,
std::vector<Context> *aux_state_ctxes,
std::unordered_map<std::string, TShape> *arg_shape_map,
std::unordered_map<std::string, int> *arg_dtype_map,
std::unordered_map<std::string, int> *arg_stype_map,
std::vector<OpReqType> *grad_req_types,
const std::unordered_set<std::string>& shared_arg_names,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec,
std::unordered_map<std::string, NDArray>* shared_buffer = nullptr,
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
// Returns symbol representing the TRT optimized graph for comparison purposes.
nnvm::Symbol GetOptimizedSymbol();
protected:
Graph ReinitGraph(Graph&& g, const Context &default_ctx,
const std::map<std::string, Context> &ctx_map,
std::vector<Context> *in_arg_ctxes,
std::vector<Context> *arg_grad_ctxes,
std::vector<Context> *aux_state_ctxes,
std::vector<OpReqType> *grad_req_types,
std::unordered_map<std::string, TShape> *arg_shape_map,
std::unordered_map<std::string, int> *arg_dtype_map,
std::unordered_map<std::string, int> *arg_stype_map,
std::unordered_map<std::string, NDArray> *params_map);
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& shared_arg_names,
const Executor* shared_exec,
std::unordered_map<std::string, NDArray>* shared_buffer,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec) override;
};
} // namespace exec
} // namespace mxnet
#endif // MXNET_USE_TENSORRT
#endif // MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_