blob: 36a7f33fe6a04ba527f5855417dd37e1ca525388 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from mxnet.gluon.model_zoo import vision
from mxnet.test_utils import assert_almost_equal
import mxnet as mx
import numpy as np
import os
# Shape the executors are bound with: one three-channel 224x224 image in
# channel-first layout (matches what get_image produces).
batch_shape = (1, 3, 224, 224)
# Sample cat picture used as the network input; '?raw=true' makes GitHub
# serve the raw file instead of the HTML page.
url = ('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/'
       'python/predict_image/cat.jpg?raw=true')
# Relative file prefix used both for HybridBlock.export and for
# mx.model.load_checkpoint in the test below.
model_file_name = 'resnet18_v2_trt_test'
def get_image(image_url):
    """Download an image and return it as a float32 single-image batch.

    The file is saved locally under the last path component of the URL with
    any query string stripped (e.g. ``cat.jpg``). The decoded image is
    resized to 224x224 and transposed with axes (2, 0, 1) to channel-first
    order. A leading batch axis is then added, the data is cast to float32,
    and the result is divided by 255.

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        An ``mx.nd.NDArray`` of dtype float32 with a leading batch axis of 1.
        NOTE(review): assumes ``mx.image.imread`` yields HWC 8-bit pixels, so
        the result is (1, C, 224, 224) with values in [0, 1] -- confirm.
    """
    local_name = image_url.rsplit('/', 1)[-1].partition('?')[0]
    path = mx.test_utils.download(image_url, fname=local_name)
    pixels = mx.image.imresize(mx.image.imread(path), 224, 224)
    # HWC -> CHW, then prepend the batch dimension.
    batch = pixels.transpose((2, 0, 1)).expand_dims(axis=0)
    return mx.nd.cast(batch, dtype=np.float32) / 255.0
def _bind_and_forward(symbol, arg_params, aux_params, data):
    """Bind *symbol* on the GPU for inference, run it once, return output 0.

    The executor is bound for ``batch_shape`` with ``grad_req='null'`` and
    ``force_rebind=True``, loaded with *arg_params*/*aux_params*, and run
    with ``is_train=False`` on *data*.

    Returns:
        NumPy array of the first output for the first sample of the batch.
    """
    executor = symbol.simple_bind(ctx=mx.gpu(), data=batch_shape,
                                  grad_req='null', force_rebind=True)
    executor.copy_params_from(arg_params, aux_params)
    outputs = executor.forward(is_train=False, data=data)
    # asnumpy() blocks until the (asynchronous) forward pass completes, so
    # the result is materialised before the caller changes any global
    # TensorRT setting such as the FP16 flag for the next binding.
    return outputs[0].asnumpy()[0]


def test_tensorrt_resnet18_feature_vect():
    """Check that TensorRT reproduces plain-MXNet ResNet-18 v2 outputs on GPU.

    A pretrained Gluon ResNet-18 v2 is exported, reloaded as a Symbol and run
    three times on the same image:

    1. with the plain MXNet executor (reference),
    2. with the TensorRT backend in FP16 mode (rtol=1e-1, atol=1e-2),
    3. with the TensorRT backend in FP32 mode (rtol=1e-4, atol=1e-4).

    The process-wide TensorRT FP16 setting is restored afterwards, even if a
    bind or forward pass raises.
    """
    print("downloading sample input")
    input_data = get_image(url)

    gluon_resnet18 = vision.resnet18_v2(pretrained=True)
    gluon_resnet18.hybridize()
    # Gluon's export() requires at least one forward pass after hybridize().
    gluon_resnet18.forward(input_data)
    gluon_resnet18.export(model_file_name)
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_file_name, 0)

    # Reference output from the non-TensorRT executor.
    no_trt_output = _bind_and_forward(sym, arg_params, aux_params, input_data)

    trt_sym = sym.get_backend_symbol('TensorRT')
    # Use the returned dicts (the params not absorbed into TensorRT nodes)
    # instead of relying on init_tensorrt_params mutating its inputs.
    arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(
        trt_sym, arg_params, aux_params)

    original_precision_value = mx.contrib.tensorrt.get_use_fp16()
    try:
        # The FP16 flag is set before each bind so that binding sees it.
        mx.contrib.tensorrt.set_use_fp16(True)
        trt_output = _bind_and_forward(trt_sym, arg_params, aux_params,
                                       input_data)
        mx.contrib.tensorrt.set_use_fp16(False)
        trt_fp32_output = _bind_and_forward(trt_sym, arg_params, aux_params,
                                            input_data)
    finally:
        mx.contrib.tensorrt.set_use_fp16(original_precision_value)

    # FP16 loses precision, hence the much looser tolerances.
    assert_almost_equal(no_trt_output, trt_output, rtol=1e-1, atol=1e-2)
    assert_almost_equal(no_trt_output, trt_fp32_output, rtol=1e-4, atol=1e-4)
if __name__ == '__main__':
    # Allow running this file directly: execute its tests under nose.
    from nose import runmodule
    runmodule()