# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import itertools
import logging
from typing import Tuple
import numpy as np
import pytest
try:
# See issue #9362.
import torch
except:
pass
import tvm
import tvm.relay.testing
import tvm.testing
from tvm import relay
from tvm.contrib.download import download
from tvm.relay import Any, GlobalVar
from tvm.relay.expr_functor import ExprVisitor
from tvm.relay.op.contrib import tensorrt
SUPPORTED_DTYPES = ["float16", "float32"]
has_tensorrt_codegen = pytest.mark.skipif(
not tensorrt.is_tensorrt_compiler_enabled(), reason="TensorRT codegen not available"
)
# CAUTION: Currently always false in CI since it adds tens of minutes to test time and depends
# on a TensorRT installation. See https://github.com/apache/tvm/issues/11765
has_tensorrt_runtime = pytest.mark.skipif(
not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available"
)
run_module = tvm.testing.parameter(
pytest.param(False, marks=[has_tensorrt_codegen, *tvm.testing.requires_cuda.marks()]),
pytest.param(
True, marks=[has_tensorrt_runtime, has_tensorrt_codegen, *tvm.testing.requires_cuda.marks()]
),
ids=["compile", "run"],
)
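# Each test below receives `run_module` from the parameter above: in the "compile" variant the
# module is partitioned and built but not executed, while in the "run" variant it is also
# executed and the TVM and TensorRT results are compared.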
def vmobj_to_list(o):
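    """Convert a VM result object (NDArray, ADT, or list) into a (possibly nested) list of numpy arrays."""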
if isinstance(o, tvm.nd.NDArray):
return [o.numpy()]
    elif isinstance(o, (tvm.runtime.container.ADT, list)):
return [vmobj_to_list(f) for f in o]
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def assert_result_dict_holds(result_dict, dtype="float16"):
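    """Assert that all results in `result_dict` are pairwise close, using looser tolerances for float16."""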
for k1, k2 in itertools.combinations(result_dict, 2):
res1 = vmobj_to_list(result_dict[k1])
res2 = vmobj_to_list(result_dict[k2])
for r1, r2 in zip(res1, res2):
if dtype == "float16":
tvm.testing.assert_allclose(r1, r2, rtol=1e-1, atol=1e-1)
else:
tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=5e-3)
def set_outer_func_attr(func, compile_name, symbol_name):
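    """Mark `func` as an outlined external function to be compiled by `compile_name` and exported as `symbol_name`."""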
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", compile_name)
func = func.with_attr("global_symbol", symbol_name)
return func
def set_inner_func_attr(func, pattern_name, composite_name):
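    """Mark `func` as a composite function matched from `pattern_name` with composite name `composite_name`."""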
func = func.with_attr("PartitionedFromPattern", pattern_name)
func = func.with_attr("Composite", composite_name)
return func
def run_and_verify_func(config, target="cuda", run_module=True, data_type="float32"):
"""Test a Relay func by compiling, running, and comparing TVM and TRT outputs.
Parameters
----------
    config : Tuple[relay.Function, Dict[str, Tuple[int, ...]], List[str]]
        A tuple containing 1) the function to test, 2) a dictionary mapping var names to input
        shapes, and 3) a list of which vars should be treated as params.
    target : str
        The TVM target ("cuda" or "llvm") used for the portion of the graph not offloaded to
        TensorRT.
    run_module : bool
        If True, the built module will be run after being compiled.
    data_type : str
        The floating point precision ("float16" or "float32") used for inputs and params.
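
    Examples
    --------
    A minimal, illustrative config for a single-op function (the names are arbitrary)::

        x = relay.var("x", shape=(1, 8), dtype="float32")
        f = relay.Function([x], relay.nn.relu(x))
        run_and_verify_func((f, {"x": (1, 8)}, []), run_module=run_module)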
"""
np.random.seed(42)
f, input_shapes, is_param = config
params = {
x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype=data_type) for x in is_param
}
input_dict = {
k: np.random.uniform(-1, 1, v).astype(dtype=data_type)
for k, v in input_shapes.items()
if k not in is_param
}
dev = tvm.device(target)
result_dict = dict()
for mode in ["vm", "graph"]:
for use_trt in [True, False]:
mod = tvm.IRModule()
mod["main"] = f
result_key = mode + ("_trt" if use_trt else "")
if use_trt:
use_fp16 = data_type == "float16"
trt_target = tvm.target.Target(f"tensorrt -use_fp16={use_fp16}")
mod = relay.transform.InferType()(mod)
mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=dev, target=[target, trt_target]
).evaluate()
else:
mod = relay.transform.InferType()(mod)
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=dev, target=target
).evaluate()
if run_module:
result_dict[result_key] = func(**input_dict, **params)
if run_module:
assert_result_dict_holds(result_dict, data_type)
def test_tensorrt_simple(run_module):
for dtype in SUPPORTED_DTYPES:
xshape = (1, 3, 2, 2)
yshape = (1, 3, 1, 1)
zshape = (1, 1, 1, 1)
x = relay.var("x", shape=(xshape), dtype=dtype)
y = relay.var("y", shape=(yshape), dtype=dtype)
z = relay.var("z", shape=(zshape), dtype=dtype)
w = z * (x + y)
out = relay.nn.relu(w)
f = relay.Function([x, y, z], out)
x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
y_data = np.random.uniform(-1, 1, yshape).astype(dtype)
z_data = np.random.uniform(-1, 1, zshape).astype(dtype)
result_dict = dict()
for mode in ["vm", "graph"]:
for use_trt in [False, True]:
mod = tvm.IRModule()
mod["main"] = f
result_key = mode + ("_trt" if use_trt else "")
if use_trt:
mod = relay.transform.InferType()(mod)
mod = tensorrt.partition_for_tensorrt(mod)
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=tvm.cuda(0), target="cuda"
).evaluate()
else:
mod = relay.transform.InferType()(mod)
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=tvm.cuda(0), target="cuda"
).evaluate()
if run_module:
result_dict[result_key] = func(x_data, y_data, z_data)
if run_module:
assert_result_dict_holds(result_dict)
def test_tensorrt_simple_cpu_io(run_module):
def get_graph():
dtype = "float32"
x_shape = (1, 3, 2, 2)
y_shape = (1, 3, 1, 1)
z_shape = (1, 1, 1, 1)
x = relay.var("x", shape=(x_shape), dtype=dtype)
y = relay.var("y", shape=(y_shape), dtype=dtype)
z = relay.var("z", shape=(z_shape), dtype=dtype)
w = z * (x + y)
out = relay.nn.relu(w)
f = relay.Function([x, y, z], out)
return f, {"x": x_shape, "y": y_shape, "z": z_shape}, ["y"]
run_and_verify_func(get_graph(), target="llvm", run_module=run_module)
def test_tensorrt_not_compatible(run_module):
dtype = "float32"
xshape = (1, 32, 14, 14)
x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
x = relay.var("x", shape=(xshape), dtype=dtype)
y = relay.add(x, x)
z = relay.cast(relay.cast(y, "int32"), "float32")
out = relay.nn.relu(z)
f = relay.Function([x], out)
mod = tvm.IRModule()
mod["main"] = f
mod = tensorrt.partition_for_tensorrt(mod)
for mode in ["graph", "vm"]:
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=tvm.cuda(0), target="cuda"
).evaluate()
if run_module:
results = func(x_data)
def test_conv1d(run_module):
def get_graph(
x_shape=((1, 3, 224)),
k_shape=(10, 3, 3),
groups=1,
padding=(1, 1),
        strides=(1,),
        dilation=(1,),
channels=None,
d_type="float16",
):
x = relay.var("x", shape=(x_shape), dtype=d_type)
kernel = relay.var("kernel", shape=(k_shape), dtype=d_type)
out = relay.nn.conv1d(
x,
kernel,
kernel_size=k_shape[2:3],
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
channels=channels,
out_dtype="float16",
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
for d_type in ["float16"]:
run_and_verify_func(
get_graph(channels=10, d_type=d_type), run_module=run_module, data_type=d_type
)
def test_conv2d(run_module):
def get_graph(
x_shape=(1, 32, 8, 8),
k_shape=(16, 32, 3, 3),
groups=1,
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
channels=None,
data_type="float16",
):
x = relay.var("x", shape=(x_shape), dtype=data_type)
kernel = relay.var("kernel", shape=(k_shape), dtype=data_type)
out = relay.nn.conv2d(
x,
kernel,
kernel_size=k_shape[2:4],
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
channels=channels,
out_dtype=data_type,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]:
for padding in [(0, 0), (1, 1)]:
for strides in [(1, 1), (2, 2)]:
for dilation in [(1, 1), (2, 2)]:
run_and_verify_func(
get_graph(
k_shape=k_shape,
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
),
run_module=run_module,
data_type="float16",
)
run_and_verify_func(
get_graph(
(1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24, data_type="float16"
),
run_module=run_module,
data_type="float16",
)
run_and_verify_func(
get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1, data_type="float32"),
run_module=run_module,
data_type="float32",
)
def test_conv2d_nhwc(run_module):
def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
out = relay.nn.conv2d(
x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
run_and_verify_func(get_graph(), run_module=run_module)
def test_conv2d_weights_const(run_module):
def get_graph(
x_shape=(1, 32, 8, 8),
k_shape=(16, 32, 3, 3),
groups=1,
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
data_type="float16",
):
x = relay.var("x", shape=(x_shape), dtype=data_type)
kernel = relay.const(np.ones(k_shape).astype(dtype=data_type))
out = relay.nn.conv2d(
x,
kernel,
channels=k_shape[0],
kernel_size=k_shape[2:4],
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for tp in ["float16"]:
run_and_verify_func(get_graph(data_type=tp), run_module=run_module, data_type=tp)
def test_conv2d_weights_transposed(run_module):
def get_graph(x_shape=(1, 32, 9, 9), k_shape=(3, 3, 32, 16), order=(3, 2, 0, 1)):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
kernel_t = relay.transpose(kernel, order)
# Conv2d requires constant weights in TensorRT, so the weights should be transposed by
# FoldConstant.
out = relay.nn.conv2d(x, kernel_t, channels=k_shape[order[0]], kernel_size=(3, 3))
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
run_and_verify_func(get_graph(), run_module=run_module)
def test_dense(run_module):
def get_graph(x_shape=(1, 16), k_shape=(32, 16), dtp="float16"):
x = relay.var("x", shape=(x_shape), dtype=dtp)
kernel = relay.var("kernel", shape=(k_shape), dtype=dtp)
# Dense requires constant weights in TensorRT, so the weights are transposed by us.
out = relay.nn.dense(x, kernel, units=k_shape[0])
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
for tp in ["float32"]:
run_and_verify_func(get_graph(dtp=tp), run_module=run_module, data_type=tp)
run_and_verify_func(get_graph(k_shape=(1, 16), dtp=tp), run_module=run_module, data_type=tp)
def test_batch_matmul(run_module):
def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True):
x = relay.var("x", shape=(x_shape), dtype="float32")
y = relay.var("y", shape=(y_shape), dtype="float32")
out = relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb)
f = relay.Function([x, y], out)
return f, {"x": x_shape, "y": y_shape}, []
run_and_verify_func(
get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True),
run_module=run_module,
)
run_and_verify_func(
get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False),
run_module=run_module,
)
run_and_verify_func(
get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True),
run_module=run_module,
)
run_and_verify_func(
get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False),
run_module=run_module,
)
def test_bias_add(run_module):
def get_graph(x_shape=(1, 16), channels=16, axis=1):
x = relay.var("x", shape=(x_shape), dtype="float32")
bias = relay.var("bias", shape=(channels,), dtype="float32")
out = relay.nn.bias_add(x, bias, axis)
f = relay.Function([x, bias], out)
return f, {"x": x_shape, "bias": (channels,)}, ["bias"]
run_and_verify_func(get_graph(), run_module=run_module)
run_and_verify_func(get_graph((1, 6, 3, 4), 6), run_module=run_module)
run_and_verify_func(get_graph((1, 6, 3, 4), 4, -1), run_module=run_module)
def test_pool2d(run_module):
def get_graph(
op,
x_shape=(1, 3, 32, 32),
pool_size=(2, 2),
strides=(2, 2),
padding=(0, 0),
ceil_mode=False,
count_include_pad=None,
):
x = relay.var("x", shape=(x_shape), dtype="float32")
if count_include_pad is not None:
out = op(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)
else:
out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for pool_size in [(2, 2), (3, 3)]:
for strides in [(1, 1), (2, 2)]:
for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]:
for ceil_mode in [False, True]:
# Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling"
if pool_size == (2, 2) and padding == (0, 0, 1, 1):
continue
for count_include_pad in [False, True]:
# Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding"
if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)):
continue
run_and_verify_func(
get_graph(
relay.nn.avg_pool2d,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
),
run_module=run_module,
)
run_and_verify_func(
get_graph(
relay.nn.max_pool2d,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
),
run_module=run_module,
)
def test_global_pool2d(run_module):
def get_graph(op, x_shape=(1, 3, 32, 32)):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = op(x)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(relay.nn.global_max_pool2d), run_module=run_module)
run_and_verify_func(get_graph(relay.nn.global_avg_pool2d), run_module=run_module)
def test_batch_flatten(run_module):
def get_graph(x_shape=(1, 3, 4, 6), data_type="float16"):
x = relay.var("x", shape=(x_shape), dtype=data_type)
out = relay.nn.batch_flatten(x)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for dtp in ["float16", "float32"]:
run_and_verify_func(get_graph(data_type=dtp), run_module=run_module, data_type=dtp)
def test_expand_dims(run_module):
def get_graph(x_shape=(1, 3), axis=1, num_newaxis=1):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = relay.expand_dims(x, axis, num_newaxis)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(), run_module=run_module)
def test_squeeze(run_module):
def get_graph(x_shape, axis, dtype):
x = relay.var("x", shape=(x_shape), dtype=dtype)
out = relay.squeeze(x, axis=axis)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for dtype in SUPPORTED_DTYPES:
run_and_verify_func(
get_graph((1, 5, 1, 1), (2, 3), dtype=dtype), run_module=run_module, data_type=dtype
)
run_and_verify_func(
get_graph((1, 3, 1), (-1,), dtype=dtype), run_module=run_module, data_type=dtype
)
def test_concatenate(run_module):
def get_graph(input_shapes, axis):
concat_inputs = []
shapes_dict = {}
for i in range(len(input_shapes)):
name = "input_{}".format(i)
concat_inputs.append(relay.var(name, shape=(input_shapes[i]), dtype="float32"))
shapes_dict[name] = input_shapes[i]
out = relay.concatenate(concat_inputs, axis)
f = relay.Function(concat_inputs, out)
return f, shapes_dict, []
run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1), run_module=run_module)
def test_split(run_module):
def get_graph(x_shape, indices_or_sections, axis):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = relay.split(x, indices_or_sections=indices_or_sections, axis=axis)
f = relay.Function([x], out.astuple())
return f, {"x": x_shape}, []
run_and_verify_func(get_graph((1, 16), indices_or_sections=2, axis=1), run_module=run_module)
run_and_verify_func(get_graph((1, 16), indices_or_sections=4, axis=1), run_module=run_module)
run_and_verify_func(get_graph((1, 16), indices_or_sections=[8], axis=1), run_module=run_module)
run_and_verify_func(
get_graph((1, 16), indices_or_sections=[2, 3, 6, 10, 14], axis=1), run_module=run_module
)
def test_conv2d_transpose(run_module):
def get_graph(
x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), groups=1, padding=(0, 0), strides=(1, 1)
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
out = relay.nn.conv2d_transpose(
x,
kernel,
channels=k_shape[1],
kernel_size=k_shape[2:4],
groups=groups,
padding=padding,
strides=strides,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
for padding in [(0, 0), (1, 1)]:
for strides in [(1, 1), (2, 2)]:
run_and_verify_func(get_graph(padding=padding, strides=strides), run_module=run_module)
def test_reshape(run_module):
def get_graph(x_shape, new_shape):
x = relay.var("x", shape=(x_shape), dtype="float16")
out = relay.reshape(x, new_shape)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(
get_graph((1, 1, 1, 10), (-1, 10)), run_module=run_module, data_type="float16"
)
run_and_verify_func(
get_graph((1, 10, 2, 3), (1, -1)), run_module=run_module, data_type="float16"
)
run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)), run_module=run_module, data_type="float16")
class AreOpsOnGraph(ExprVisitor):
"""
    Visits the graph recursively and checks whether it contains any op from op_list.
"""
def __init__(self, op_list):
ExprVisitor.__init__(self)
self.op_list = op_list
self.on_graph = False
def visit_call(self, call):
if isinstance(call.op, tvm.tir.op.Op):
if str(call.op.name) in self.op_list:
self.on_graph = True
return super().visit_call(call)
def are_ops_on_graph(self, subgraph) -> bool:
"""
        Recursively visit the subgraph and check whether any op from op_list appears in it.
"""
self.visit(subgraph)
return self.on_graph
def are_ops_on_trt(mod, op_list):
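    """Return True if ops from op_list appear in TensorRT-offloaded subgraphs and not in the remaining TVM functions."""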
op_on_trt = False
op_on_tvm = False
for subgraph in mod.get_global_vars():
name = subgraph.name_hint
if mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
op_on_trt |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
else:
op_on_tvm |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
return op_on_trt and not op_on_tvm
def test_dynamic_reshape(run_module):
def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
result_arr = [{} for _ in range(len(x_data_list))]
for use_trt in [True, False]:
x = relay.var("x", shape=x_shape, dtype="float32")
out = relay.reshape(x, new_shape)
f = relay.Function([x], out)
mod = tvm.IRModule()
mod["main"] = f
if use_trt:
logging.info("Before partitioning:\n%s", mod)
mod = tensorrt.partition_for_tensorrt(mod)
logging.info("After partitioning:\n%s", mod)
assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt
if run_module:
                with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
"vm", mod=mod, device=tvm.cpu(0), target="llvm"
).evaluate()
for i, x_data in enumerate(x_data_list):
result_arr[i][use_trt] = func(x_data)
if run_module:
for i in range(len(x_data_list)):
assert_result_dict_holds(result_arr[i])
dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
x_shape = (relay.Any(), 3, 2, 3)
x_data_list = [
np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
]
new_shape = (-1, 3, 2, 3)
should_offload_to_trt = True
test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
x_shape = (relay.Any(), 3, 2, 3)
x_data_list = [
np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
]
new_shape = (-1, 1, 2, 3)
should_offload_to_trt = False
test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
x_shape = (1, relay.Any(), 2, 3)
x_data_list = [
np.ones(list(x_shape[:1]) + [dim_value] + list(x_shape)[2:]).astype("float32")
for dim_value in dim_values
]
new_shape = (1, -1, 2, 3)
should_offload_to_trt = False
test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
def test_transpose(run_module):
def get_graph(x_shape, order):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = relay.transpose(x, order)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph((1, 16, 7, 7), [0, 2, 3, 1]), run_module=run_module)
run_and_verify_func(get_graph((1, 7, 7, 16), [0, 3, 1, 2]), run_module=run_module)
def test_float_const(run_module):
def get_graph(x_shape=(1, 16)):
x = relay.var("x", shape=(x_shape), dtype="float32")
beta = relay.const(1, dtype="float32")
out = relay.multiply(x, beta)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(), run_module=run_module, data_type="float32")
def test_float_const16(run_module):
def get_graph(x_shape=(1, 16)):
x = relay.var("x", shape=(x_shape), dtype="float16")
beta = relay.const(1, dtype="float16")
out = relay.multiply(x, beta)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(), run_module=run_module, data_type="float16")
def test_pad(run_module):
def get_graph(x_shape, pad_width):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = relay.nn.pad(x, pad_width=pad_width)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(
get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]]), run_module=run_module
)
run_and_verify_func(
get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]]), run_module=run_module
)
run_and_verify_func(
get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]]), run_module=run_module
)
run_and_verify_func(
get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]),
run_module=run_module,
)
def test_add(run_module):
def get_graph(x_shape):
x = relay.var("x", shape=(x_shape), dtype="float16")
y = relay.var("y", shape=(x_shape), dtype="float16")
out = relay.add(x, y)
f = relay.Function([x, y], out)
return f, {"x": x_shape, "y": x_shape}, []
run_and_verify_func(get_graph((1, 1000)), run_module=run_module, data_type="float16")
def test_softmax(run_module):
def get_graph(x_shape, axis, data_type="float32"):
x = relay.var("x", shape=(x_shape), dtype=data_type)
out = relay.nn.softmax(x, axis=axis)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(
get_graph((1, 1000), axis=1, data_type="float32"),
run_module=run_module,
data_type="float32",
)
run_and_verify_func(
get_graph((1, 1000), axis=-1, data_type="float32"),
run_module=run_module,
data_type="float32",
)
run_and_verify_func(
get_graph((1, 3, 4), axis=-2, data_type="float16"),
run_module=run_module,
data_type="float16",
)
run_and_verify_func(
get_graph((1, 3, 4), axis=1, data_type="float16"),
run_module=run_module,
data_type="float16",
)
def test_batch_norm(run_module):
def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
x = relay.var("x", shape=(x_shape), dtype="float32")
beta = relay.var("beta", shape=(param_shape), dtype="float32")
gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
moving_mean = relay.var("moving_mean", shape=(param_shape), dtype="float32")
moving_var = relay.var("moving_var", shape=(param_shape), dtype="float32")
out, _, _ = relay.nn.batch_norm(
x,
gamma=gamma,
beta=beta,
moving_mean=moving_mean,
moving_var=moving_var,
axis=axis,
center=True,
scale=True,
epsilon=epsilon,
)
f = relay.Function([x, gamma, beta, moving_mean, moving_var], out)
return (
f,
{
"x": x_shape,
"beta": param_shape,
"gamma": param_shape,
"moving_mean": param_shape,
"moving_var": param_shape,
},
["beta", "gamma", "moving_mean", "moving_var"],
)
run_and_verify_func(get_graph((1, 64, 56, 56), (64,)), run_module=run_module)
run_and_verify_func(
get_graph((1, 56, 56, 64), (64,), axis=3, epsilon=1.001e-05), run_module=run_module
)
run_and_verify_func(get_graph((1, 4, 8, 4), (8,), axis=2), run_module=run_module)
run_and_verify_func(get_graph((1, 8, 4, 4, 4), (8,), axis=1), run_module=run_module)
run_and_verify_func(get_graph((1, 4, 8, 4, 4), (8,), axis=2), run_module=run_module)
run_and_verify_func(get_graph((1, 4, 4, 4, 8), (8,), axis=4), run_module=run_module)
run_and_verify_func(get_graph((1, 8), (8,), axis=1), run_module=run_module)
run_and_verify_func(get_graph((1, 3, 8), (8,), axis=2), run_module=run_module)
def test_layer_norm(run_module):
def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
x = relay.var("x", shape=(x_shape), dtype="float32")
gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
beta = relay.var("beta", shape=(param_shape), dtype="float32")
out = relay.nn.layer_norm(
x, gamma=gamma, beta=beta, axis=axis, epsilon=epsilon, center=True, scale=True
)
f = relay.Function([x, gamma, beta], out)
return (f, {"x": x_shape, "beta": param_shape, "gamma": param_shape}, ["beta", "gamma"])
run_and_verify_func(get_graph((1, 32, 8, 8), (32,)), run_module=run_module)
run_and_verify_func(
get_graph((1, 8, 8, 32), (32,), axis=3, epsilon=1.001e-05), run_module=run_module
)
run_and_verify_func(get_graph((1, 8), (8,), axis=1), run_module=run_module)
def test_unary(run_module):
def get_graph(op, x_shape=(1, 8, 3, 3)):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = op(x)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for op in [
relay.nn.relu,
relay.sigmoid,
relay.tanh,
relay.exp,
relay.log,
relay.sqrt,
relay.abs,
relay.negative,
relay.sin,
relay.cos,
relay.atan,
relay.ceil,
relay.floor,
relay.erf,
]:
run_and_verify_func(get_graph(op), run_module=run_module)
def test_clip(run_module):
def get_graph(x_shape=(1, 8, 3, 3)):
x = relay.var("x", shape=(x_shape), dtype="float16")
out = relay.clip(x, a_min=-0.2, a_max=0.4)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(), run_module=run_module, data_type="float16")
def test_relu(run_module):
def get_graph(x_shape=(1, 8, 3, 4)):
x = relay.var("x", shape=(x_shape), dtype="float16")
out = relay.nn.relu(x)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(), run_module=run_module, data_type="float16")
def test_leaky_relu(run_module):
def get_graph(x_shape=(1, 8, 3, 4)):
x = relay.var("x", shape=(x_shape), dtype="float16")
out = relay.nn.leaky_relu(x, alpha=0.1)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(), run_module=run_module, data_type="float16")
def test_binary(run_module):
def get_graph(op, x_shape, y_shape, y_is_const=False, d_type="float16"):
x = relay.var("x", shape=(x_shape), dtype=d_type)
if y_is_const:
y = relay.const(np.ones(y_shape).astype(d_type))
out = op(x, y)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
y = relay.var("y", shape=(y_shape), dtype=d_type)
out = op(x, y)
f = relay.Function([x, y], out)
return f, {"x": x_shape, "y": y_shape}, []
for op in [relay.add, relay.subtract, relay.multiply, relay.divide, relay.power]:
for d_type in SUPPORTED_DTYPES:
for y_is_const in [True, False]:
run_and_verify_func(
get_graph(op, (1, 8, 3, 3), (1, 8, 3, 3), y_is_const, d_type),
run_module=run_module,
data_type=d_type,
)
run_and_verify_func(
get_graph(op, (1, 8, 1, 3), (1, 8, 3, 1), y_is_const, d_type),
run_module=run_module,
data_type=d_type,
)
run_and_verify_func(
get_graph(op, (1, 10), (10,), y_is_const, d_type),
run_module=run_module,
data_type=d_type,
)
run_and_verify_func(
get_graph(op, (1, 1, 1, 10), (10,), y_is_const, d_type),
run_module=run_module,
data_type=d_type,
)
run_and_verify_func(
get_graph(op, (1, 1, 1), (3,), y_is_const, d_type),
run_module=run_module,
data_type=d_type,
)
def test_reduce(run_module):
def get_graph(op, x_shape=(1, 2, 3, 4), axis=(2, 3), keepdims=False, d_type="float32"):
x = relay.var("x", shape=(x_shape), dtype=d_type)
out = op(x, axis=axis, keepdims=keepdims)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for type in SUPPORTED_DTYPES:
for op in [relay.sum, relay.prod, relay.max, relay.min, relay.mean]:
for keepdims in [True, False]:
run_and_verify_func(
get_graph(op, axis=(1), keepdims=keepdims, d_type=type),
run_module=run_module,
data_type=type,
)
run_and_verify_func(
get_graph(op, axis=(2, 3), keepdims=keepdims, d_type=type),
run_module=run_module,
data_type=type,
)
run_and_verify_func(
get_graph(op, axis=(1, 2), keepdims=keepdims, d_type=type),
run_module=run_module,
data_type=type,
)
run_and_verify_func(
get_graph(op, axis=(1, 2, 3), keepdims=keepdims, d_type=type),
run_module=run_module,
data_type=type,
)
def test_strided_slice(run_module):
def get_graph(x_shape, begin, end, strides=None, slice_mode="size"):
x = relay.var("x", shape=(x_shape), dtype="float32")
if strides:
out = relay.strided_slice(x, begin, end, strides, slice_mode=slice_mode)
else:
out = relay.strided_slice(x, begin, end, slice_mode=slice_mode)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for slice_mode in ["size", "end"]:
run_and_verify_func(
get_graph((1, 3, 6, 7), (0, 0, 0, 0), (1, 1, 6, 7), slice_mode=slice_mode),
run_module=run_module,
)
run_and_verify_func(
get_graph((1, 3, 6, 7), [0, 1, 0, 0], [1, 2, 6, 6], slice_mode=slice_mode),
run_module=run_module,
)
run_and_verify_func(
get_graph((2, 3, 6, 7), [0, 0, 0, 0], [-1, -1, -1, -1], slice_mode=slice_mode),
run_module=run_module,
)
run_and_verify_func(
get_graph((2, 3, 6, 7), [0, 1, 0, 0], [-1, -1, -1, -1], slice_mode=slice_mode),
run_module=run_module,
)
run_and_verify_func(
get_graph((1, 6), [0, 1], [1, 3], slice_mode=slice_mode), run_module=run_module
)
def test_adaptive_pool2d(run_module):
def get_graph(op, x_shape=(1, 3, 32, 32), out_size=(1, 1), data_type="float16"):
x = relay.var("x", shape=(x_shape), dtype=data_type)
out = op(x, out_size)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
for type in SUPPORTED_DTYPES:
run_and_verify_func(
get_graph(relay.nn.adaptive_max_pool2d, data_type=type),
run_module=run_module,
data_type=type,
)
run_and_verify_func(
get_graph(relay.nn.adaptive_avg_pool2d, data_type=type),
run_module=run_module,
data_type=type,
)
def test_multiple_outputs(run_module):
def get_graph(d_type="float16"):
x = relay.var("x", shape=(1, 3), dtype=d_type)
y = relay.var("y", shape=(1, 3), dtype=d_type)
z = relay.add(x, y)
w = relay.add(z, y)
out = relay.Tuple((z, w))
f = relay.Function([x, y], out)
return f, {"x": (1, 3), "y": (1, 3)}, []
for type in SUPPORTED_DTYPES:
run_and_verify_func(get_graph(d_type=type), run_module=run_module, data_type=type)
@pytest.mark.skip(reason=("Fails assert_allclose. See https://github.com/apache/tvm/issues/11765"))
def test_conv3d(run_module):
def get_graph(
x_shape=(1, 24, 8, 8, 8),
k_shape=(16, 24, 3, 3, 3),
groups=1,
padding=(0, 0, 0),
strides=(1, 1, 1),
dilation=(1, 1, 1),
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
out = relay.nn.conv3d(
x,
kernel,
channels=k_shape[0],
kernel_size=k_shape[2:],
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
run_and_verify_func(get_graph(), run_module=run_module)
run_and_verify_func(get_graph(padding=(0, 0, 0, 1, 1, 1)), run_module=run_module)
def test_pool3d(run_module):
def get_graph(
op,
x_shape=(1, 3, 8, 32, 32),
pool_size=(2, 2, 2),
strides=(2, 2, 2),
padding=(0, 0, 0),
ceil_mode=False,
count_include_pad=None,
):
x = relay.var("x", shape=(x_shape), dtype="float32")
if count_include_pad is not None:
out = op(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)
else:
out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
f = relay.Function([x], out)
return f, {"x": x_shape}, []
run_and_verify_func(get_graph(relay.nn.avg_pool3d), run_module=run_module)
run_and_verify_func(get_graph(relay.nn.max_pool3d), run_module=run_module)
run_and_verify_func(
get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1)), run_module=run_module
)
run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module)
def test_conv3d_transpose(run_module):
def get_graph(
x_shape=(1, 32, 8, 8, 8),
k_shape=(32, 16, 3, 3, 3),
groups=1,
padding=(0, 0, 0),
strides=(1, 1, 1),
output_padding=(0, 0, 0),
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
out = relay.nn.conv3d_transpose(
x,
kernel,
channels=k_shape[1],
kernel_size=k_shape[2:5],
groups=groups,
padding=padding,
strides=strides,
output_padding=output_padding,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
run_and_verify_func(get_graph(), run_module=run_module)
run_and_verify_func(get_graph(strides=(2, 2, 2)), run_module=run_module)
run_and_verify_func(
get_graph(strides=(2, 2, 2), output_padding=(1, 1, 1)), run_module=run_module
)
@has_tensorrt_codegen
def test_dynamic_offload():
"""
This test checks for proper dynamic offloading of relay graphs. An addition between
the outputs of two conv2d's is performed, one of them having all static args whereas
the other has a arg with dynamic shape. It is expected for the TRT partitioner to
offload the conv2d with dynamic arg to TVM while running the other in TRT.
"""
data_shape = (1, 32, 8, 8)
k_shape = (1, 32, 3, 3)
x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()), dtype="float32")
y = relay.var("y", shape=(data_shape), dtype="float32")
kernel = relay.const(np.random.rand(*k_shape).astype("float32"))
def get_expected():
# Create a nested TRT function that matches the expected output
mod = tvm.IRModule()
outer_var = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
inner_var = relay.var("FunctionVar_0_0", shape=(data_shape), dtype="float32")
inner_body = relay.nn.conv2d(
inner_var, kernel, channels=k_shape[0], kernel_size=k_shape[2:4]
)
inner_func = relay.Function([inner_var], inner_body)
inner_func = set_inner_func_attr(inner_func, "nn.conv2d_", "tensorrt.nn.conv2d")
outer_body = inner_func(outer_var)
outer_func = relay.Function([outer_var], outer_body)
outer_func = set_outer_func_attr(outer_func, "tensorrt", "tvmgen_default_tensorrt_main_0")
gv = GlobalVar("tvmgen_default_tensorrt_main_0")
mod[gv] = outer_func
mod = relay.transform.InferType()(mod)
# Create the main function
out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
out = relay.add(out1, gv(y))
f = relay.Function([x, y], out)
mod["main"] = f
mod = relay.transform.InferType()(mod)
return mod
# Create relay function that will be offloaded to TRT
out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
out = relay.add(out1, out2)
f = relay.Function([x, y], out)
# Pass the function to TRT compilation
mod = tvm.IRModule()
mod["main"] = f
mod = relay.transform.InferType()(mod)
mod_trt = tensorrt.partition_for_tensorrt(mod)
# Get the expected relay graph and compare
mod_exp = get_expected()
tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)
def test_tensorrt_dynamic_batch(run_module):
batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2]
x_shape = (relay.Any(), 1, 8, 8)
x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
result_arr = [{} for _ in range(len(batches_to_test))]
for use_trt in [True, False]:
x = relay.var("x", shape=x_shape, dtype="float32")
out = relay.nn.relu(x)
f = relay.Function([x], out)
mod = tvm.IRModule()
mod["main"] = f
if use_trt:
mod = tensorrt.partition_for_tensorrt(mod)
if run_module:
            with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
"vm", mod=mod, device=tvm.cpu(0), target="llvm"
).evaluate()
for i, batch_size in enumerate(batches_to_test):
result_arr[i][use_trt] = func(x_data[:batch_size, ...])
if run_module:
for i in range(len(batches_to_test)):
assert_result_dict_holds(result_arr[i])
def test_tensorrt_dynamic_batch_conv(run_module):
batches_to_test = [1, 5, 1, 0, 2, 3, 0, 1, 3, 2]
x_shape = (relay.Any(), 32, 8, 8)
x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
k_shape = (16, 32, 3, 3)
params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")}
for use_implicit_batch in [True, False]:
result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))]
for use_trt in [True, False]:
x = relay.var("x", shape=x_shape, dtype="float32")
kernel = relay.var("kernel", shape=k_shape, dtype="float32")
out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1)
f = relay.Function([x, kernel], out)
mod = tvm.IRModule()
mod["main"] = f
trt_target = tvm.target.Target(f"tensorrt -use_implicit_batch={use_implicit_batch}")
if use_trt:
mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)
if run_module:
for target in ["llvm", "cuda"]:
targets = [target]
if use_trt:
targets.append(trt_target)
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
"vm", mod=mod, device=tvm.device(target), target=targets
).evaluate()
for i, batch_size in enumerate(batches_to_test):
result_arr[i][target][use_trt] = func(x_data[:batch_size, ...], **params)
if run_module:
for i in range(len(batches_to_test)):
for target in ["llvm", "cuda"]:
assert_result_dict_holds(result_arr[i][target])
def test_maskrcnn_resnet50(run_module) -> None:
"""
This function tests the working of pytorch maskrcnn with resnet50 as backbone with
VM and VM + TRT. Since the order of compiled model outputs is a bit different from
original pytorch model, it uses a custom logic for comparison check.
"""
import torch
import torchvision
def convert_traced_model_to_vm_trt(
traced_module: torch.jit.TopLevelTracedModule, np_sample_input: np.ndarray, target: str
) -> tvm.runtime.vm.Executable:
"""
This function converts a traced pytorch model to VM + TRT.
"""
input_shape = np_sample_input.shape
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(traced_module, shape_list)
trt_target = tvm.target.Target("tensorrt -remove_no_mac_subgraphs=True")
mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)
targets = [target, trt_target]
with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
vm_trt_exec = relay.vm.compile(mod, target=targets, params=params)
return vm_trt_exec
class TraceWrapper(torch.nn.Module):
"""
This class is a wrapper over the torch module to convert the outputs into traceable form
"""
def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model
def forward(
self, inp: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
out = self.model(inp)
return out[0]["boxes"], out[0]["scores"], out[0]["labels"], out[0]["masks"]
def get_traced_maskrcnn_model(np_sample_input: np.ndarray) -> torch.jit.TopLevelTracedModule:
"""
This function takes a sample input and returns the traced maskrcnn model
"""
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))
model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=np_sample_input.shape))
with torch.no_grad():
out = model(inp)
traced_module = torch.jit.trace(model, inp)
traced_module.eval()
return traced_module
def get_maskrcnn_input(in_size: int) -> np.ndarray:
"""
This function gets a real image with multiple objects of interest and returns it.
"""
input_shape = (1, 3, in_size, in_size)
img_path = "test_street_small.jpg"
img_url = "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
download(img_url, img_path)
import cv2
img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)
return img
in_size = 300
np_sample_input = get_maskrcnn_input(in_size)
traced_module = get_traced_maskrcnn_model(np_sample_input)
vm_trt_exec = convert_traced_model_to_vm_trt(traced_module, np_sample_input, target="llvm")
if run_module:
dev = tvm.cpu()
vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, dev)
vm.set_input("main", **{"input0": np_sample_input})
tvm_res = vm.run()
        # Sort the scores in descending order and keep the high-confidence indices. In this
        # example 9 is chosen because this image has 9 boxes with confidence over 0.9.
num_high_confidence_boxes = 9
tvm_indices = np.argsort(-1 * tvm_res[1].numpy())[:num_high_confidence_boxes]
with torch.no_grad():
out = traced_module(torch.Tensor(np_sample_input))
# Descending sort by scores and get the high confidence indices
pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes]
# [Box Tol, Score Tol, Label Tol, Mask Tol]
tol = [1e-1, 5e-3, 1e-5, 4e-1]
        # Certain ops introduce minor differences between the TVM and PyTorch outputs, so the
        # tolerance can't be 1e-4 or 1e-5 throughout. The ideal way around this would be to run
        # an entire dataset and compare mAP with the original model; since that is not practical
        # on CI, the following compromise is made. These tolerances are chosen based on their
        # impact (or lack thereof) on the mAP score, e.g. a 0.1 pixel difference for a box in a
        # 300x300 image won't change anything.
for i, tol_val in zip(range(4), tol):
np.testing.assert_allclose(
tvm_res[i].numpy()[tvm_indices],
out[i].numpy()[pt_indices],
rtol=tol_val,
atol=tol_val,
)
def test_empty_subgraph(run_module):
x_shape = (1, 3, 5)
mod = tvm.IRModule()
# Empty tensorrt subgraph.
var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32")
f1 = GlobalVar("tensorrt_0")
func = relay.Function([var1], var1)
func = set_outer_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0")
mod[f1] = func
mod = relay.transform.InferType()(mod)
# Create the main function
x = relay.var("x", shape=x_shape, dtype="float32")
out = f1(relay.nn.relu(x))
f = relay.Function([x], out)
mod["main"] = f
x_data = np.random.uniform(-1, 1, x_shape).astype("float32")
for mode in ["graph", "vm"]:
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=mod, device=tvm.cuda(0), target="cuda"
).evaluate()
if run_module:
results = func(x_data)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
tvm.testing.main()