/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file onnx_to_tensorrt.cc
 * \brief TensorRT integration with the MXNet executor
 * \author Marek Kolodziej, Clement Fuji Tsang
 */
#if MXNET_USE_TENSORRT
#include "./onnx_to_tensorrt.h"
#include <onnx/onnx_pb.h>
#include <NvInfer.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <onnx-tensorrt/NvOnnxParser.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>

using std::cerr;
using std::cout;
using std::endl;

namespace onnx_to_tensorrt {

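// Decode ONNX's packed IR version (major * 1000000 + minor * 10000 + patch)
// into a human-readable "major.minor.patch" string.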
std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
  int onnx_ir_major = ir_version / 1000000;
  int onnx_ir_minor = ir_version % 1000000 / 10000;
  int onnx_ir_patch = ir_version % 10000;
  return (std::to_string(onnx_ir_major) + "." + std::to_string(onnx_ir_minor) + "." +
          std::to_string(onnx_ir_patch));
}

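// Print the ONNX IR and TensorRT versions this parser was built against.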
void PrintVersion() {
  cout << "Parser built against:" << endl;
  cout << " ONNX IR version: " << onnx_ir_version_string(onnx::IR_VERSION) << endl;
  cout << " TensorRT version: " << NV_TENSORRT_MAJOR << "." << NV_TENSORRT_MINOR << "."
       << NV_TENSORRT_PATCH << endl;
}

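// Parse a serialized ONNX model and build a TensorRT engine from it. The parser and
// logger are returned alongside the engine so they can be kept alive for as long as
// the engine is in use.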
std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
           unique_ptr<nvonnxparser::IParser>,
           std::unique_ptr<TRT_Logger> >
onnxToTrtCtx(const std::string& onnx_model,
             int32_t max_batch_size,
             size_t max_workspace_size,
             nvinfer1::ILogger::Severity verbosity,
             bool debug_builder) {
  GOOGLE_PROTOBUF_VERIFY_VERSION;

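  // Create the TensorRT builder, an explicit-batch network definition (the ONNX
  // parser only supports explicit-batch networks), and the onnx-tensorrt parser.
  // InferObject (see onnx_to_tensorrt.h) wraps each raw TensorRT pointer so it is
  // released automatically.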
  auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
  auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger));
  const auto explicitBatch =
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch));
  auto trt_parser = InferObject(nvonnxparser::createParser(*trt_network, *trt_logger));
  ::ONNX_NAMESPACE::ModelProto parsed_model;
  // We check for a valid parse, but the main effect is the side effect
  // of populating parsed_model
  if (!parsed_model.ParseFromString(onnx_model)) {
    throw dmlc::Error("Could not parse ONNX from string");
  }

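  // Hand the same serialized model to the onnx-tensorrt parser, which populates
  // trt_network. On failure, walk the parser's error list and report the offending
  // ONNX node (looked up in parsed_model) before giving up.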
  if (!trt_parser->parse(onnx_model.c_str(), onnx_model.size())) {
    int nerror = trt_parser->getNbErrors();
    for (int i = 0; i < nerror; ++i) {
      nvonnxparser::IParserError const* error = trt_parser->getError(i);
      if (error->node() != -1) {
        ::ONNX_NAMESPACE::NodeProto const& node = parsed_model.graph().node(error->node());
        cerr << "While parsing node number " << error->node() << " [" << node.op_type();
        if (!node.output().empty()) {
          cerr << " -> \"" << node.output(0) << "\"";
        }
        cerr << "]:" << endl;
        if (static_cast<int>(verbosity) >= static_cast<int>(nvinfer1::ILogger::Severity::kINFO)) {
          cerr << "--- Begin node ---" << endl;
          cerr << node.DebugString() << endl;
          cerr << "--- End node ---" << endl;
        }
      }
      cerr << "ERROR: " << error->file() << ":" << error->line() << " In function "
           << error->func() << ":\n"
           << "[" << static_cast<int>(error->code()) << "] " << error->desc() << endl;
    }
    throw dmlc::Error("Cannot parse ONNX into TensorRT Engine");
  }

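  // FP16 kernels are enabled by default and can be disabled by setting the
  // MXNET_TENSORRT_USE_FP16 environment variable to 0; platforms without fast
  // fp16 support stay in fp32 and emit a warning.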
  if (dmlc::GetEnv("MXNET_TENSORRT_USE_FP16", true)) {
    if (trt_builder->platformHasFastFp16()) {
      trt_builder->setFp16Mode(true);
    } else {
      LOG(WARNING) << "TensorRT can't use fp16 on this platform";
    }
  }

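  // Apply the builder limits supplied by the caller and build the engine;
  // max_workspace_size caps the scratch GPU memory TensorRT may use while
  // selecting kernels for the engine.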
  trt_builder->setMaxBatchSize(max_batch_size);
  trt_builder->setMaxWorkspaceSize(max_workspace_size);
  trt_builder->setDebugSync(debug_builder);
  auto trt_engine = InferObject(trt_builder->buildCudaEngine(*trt_network));
  return std::make_tuple(std::move(trt_engine), std::move(trt_parser), std::move(trt_logger));
}

} // namespace onnx_to_tensorrt
#endif // MXNET_USE_TENSORRT