blob: fdac859aef6f4fd005e9a10c489ae7a61f399abc [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import gluoncv
import mxnet as mx
from tests.python.unittest.common import assertRaises
def test_training_without_trt():
    """Bind the ResNet symbol in training mode with the TensorRT flag off."""
    run_resnet(use_tensorrt=False, is_train=True)
def test_inference_without_trt():
    """Bind the ResNet symbol in inference mode with the TensorRT flag off."""
    run_resnet(use_tensorrt=False, is_train=False)
def test_training_with_trt():
    """Expect a RuntimeError when binding for training with the TensorRT flag on."""
    assertRaises(
        RuntimeError,
        run_resnet,
        is_train=True,
        use_tensorrt=True,
    )
def test_inference_with_trt():
    """Bind the ResNet symbol in inference mode with the TensorRT flag on."""
    run_resnet(use_tensorrt=True, is_train=False)
def run_resnet(is_train, use_tensorrt):
    """Bind a pretrained ``cifar_resnet20_v1`` symbol on GPU 0.

    The global use-TensorRT flag is set to ``use_tensorrt`` for the duration
    of the call and restored to its previous value afterwards, even when
    binding raises.

    Parameters
    ----------
    is_train : bool
        When true the executor is bound with ``grad_req='write'``;
        otherwise with ``grad_req='null'``.
    use_tensorrt : bool
        When true the symbol is bound through
        ``mx.contrib.tensorrt.tensorrt_bind`` (with the network parameters
        passed in); otherwise through ``Symbol.simple_bind``.
    """
    saved_flag = mx.contrib.tensorrt.get_use_tensorrt()
    try:
        mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt)

        # Fixed single-image CIFAR-sized input on the first GPU.
        ctx = mx.gpu(0)
        batch_size = 1
        h = w = 32

        net = gluoncv.model_zoo.get_model('cifar_resnet20_v1', pretrained=True)
        softmax = mx.sym.SoftmaxOutput(net(mx.sym.var('data')), name='softmax')

        grad_req = 'write' if is_train else 'null'

        # Keyword arguments shared by both bind entry points.
        bind_kwargs = {
            'data': (batch_size, 3, h, w),
            'softmax_label': (batch_size,),
            'force_rebind': True,
            'grad_req': grad_req,
        }

        if not use_tensorrt:
            softmax.simple_bind(ctx=ctx, **bind_kwargs)
        else:
            all_params = {
                name: param.data()
                for name, param in net.collect_params().items()
            }
            mx.contrib.tensorrt.tensorrt_bind(
                softmax, ctx=ctx, all_params=all_params, **bind_kwargs
            )
    finally:
        # Always put the global flag back so other tests are unaffected.
        mx.contrib.tensorrt.set_use_tensorrt(saved_flag)
if __name__ == '__main__':
    # Running this file directly hands the module to nose's runner.
    import nose

    nose.runmodule()