# blob: 6e78a67fd0859d92b88123c6fac661211b1a6aea [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import os
import tvm
import tvm.testing
from tvm import relax, tir
import numpy as np
from tvm.script import relax as R, ir as I, tir as T
from tvm.relax.testing import transform
import tempfile
from tvm.relax.transform.tuning_api import Trace
from tvm.relax.dpl import is_op, wildcard
# Handles to the TensorRT hooks. Both values are only ever truth-tested (and
# the runtime one called) below, so a missing hook is expected to be falsy.
env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True)
env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True)

# Skip marker for tests that need the TensorRT code generator.
requires_tensorrt_codegen = pytest.mark.skipif(
    not env_checker_codegen,
    reason="TensorRT codegen not available",
)
# Skip marker for tests that additionally execute the generated engine.
requires_tensorrt_runtime = pytest.mark.skipif(
    not env_checker_runtime or not env_checker_runtime(),
    reason="TensorRT runtime not available",
)

# Global variable in pytest that applies markers to all tests.
pytestmark = [requires_tensorrt_codegen] + tvm.testing.requires_cuda.marks()

# Target gpu
target_str = "nvidia/nvidia-t4"
# Built from the tag above; used both as a context manager and as the build target.
target = tvm.target.Target(target_str)
dev = tvm.cuda()
def check_executable(exec, dev, inputs, expected, entry_func_name):
    """Run one entry function of ``exec`` on ``dev`` and compare with ``expected``.

    The executable is wrapped in a Relax virtual machine, the function named
    ``entry_func_name`` is called with ``inputs`` unpacked as positional
    arguments, and its result must match ``expected`` within an absolute and
    relative tolerance of 1e-5.
    """
    entry = relax.VirtualMachine(exec, dev)[entry_func_name]
    actual = entry(*inputs)
    tvm.testing.assert_allclose(
        actual.numpy(),
        expected.numpy(),
        atol=1e-5,
        rtol=1e-5,
    )
def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
    """Check that ``exec0`` survives an export/load round trip.

    ``exec0`` is exported to a shared library inside a temporary directory and
    loaded back as ``exec1``. Both must report the same ``stats`` and
    ``as_text`` output, and both must produce ``expected`` when
    ``entry_func_name`` is run on ``dev`` with ``inputs``.
    """
    # Export into a throwaway directory so no artifact is left behind, even
    # when one of the checks below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        lib_path = os.path.join(tmp_dir, "exec.so")
        exec0.export_library(lib_path)
        exec1 = tvm.runtime.load_module(lib_path)

        # The loaded module exposes the same queries as packed functions.
        assert exec0.stats() == exec1["stats"]()
        assert exec0.as_text() == exec1["as_text"]()

        check_executable(exec0, dev, inputs, expected, entry_func_name)
        check_executable(exec1, dev, inputs, expected, entry_func_name)
def gen_ground_truth(mod, target, dev, inputs):
    """Compute the reference output of ``mod`` without any BYOC offloading.

    The module is legalized and given the default GPU schedule under
    ``target``, checked for well-formedness, built, and its ``main`` function
    is run on ``dev`` with ``inputs``. The result of that call is returned.
    """
    # Lower and run tuning
    # Since there is no default schedule for GPU in MS yet, this is necessary
    with target:
        seq = tvm.transform.Sequential(
            [relax.transform.LegalizeOps(), tir.transform.DefaultGPUSchedule()]
        )
        new_mod = seq(mod)
    assert relax.analysis.well_formed(new_mod)
    # Named `executable` rather than `exec` to avoid shadowing the builtin.
    executable = relax.build(new_mod, target, params={})
    vm = relax.VirtualMachine(executable, dev)
    return vm["main"](*inputs)
# Small element-wise model (two multiplies, three adds on 16x16 float32
# tensors) shared by the end-to-end tests below.
@I.ir_module
class InputModule:
    @R.function
    def main(
        x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32")
    ) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            z1 = R.multiply(x, y)
            z2 = R.add(z1, x)
            z3 = R.add(z1, z2)
            z4 = R.multiply(z3, z2)
            z5 = R.add(z4, z1)
            # z5 is returned outside the dataflow block, so it must be an output.
            R.output(z5)
        return z5
def setup_test():
    """Return ``(mod, inputs, expected)`` for the end-to-end tests.

    ``mod`` is ``InputModule``, ``inputs`` are two random 16x16 float32 arrays
    placed on ``dev``, and ``expected`` is the output of ``gen_ground_truth``
    for those inputs.
    """
    # Prepare IRModule and its input
    mod = InputModule
    assert isinstance(mod, tvm.IRModule)

    inputs = []
    for _ in range(2):
        host_array = np.random.rand(16, 16).astype(np.float32)
        inputs.append(tvm.nd.array(host_array, dev))

    # Ground truth should be generated before annotation
    # due to the conflict with MS task extraction
    # TODO(@sunggg): Sort this out
    expected = gen_ground_truth(mod, target, dev, inputs)
    return mod, inputs, expected
# Run the test below once per entry-function name.
entry_func_name = tvm.testing.parameter("main", "func")


@requires_tensorrt_runtime
def test_tensorrt_only(entry_func_name):
    """Offload the whole model to TensorRT, then check output and round trip.

    When ``entry_func_name`` is not ``"main"`` the entry function is renamed
    first, so the pipeline is also exercised with a non-default entry point.
    """
    mod, inputs, expected = setup_test()

    if entry_func_name != "main":
        # setup_test() returns the shared InputModule object; rename on a copy
        # so the other tests still find "main".
        mod = mod.clone()
        mod[entry_func_name] = mod["main"].with_attr("global_symbol", entry_func_name)
        del mod["main"]

    # Define patterns that we want to offload to byoc
    # This test will offload entire model
    # Thus, define patterns for both `multiply` and `add` ops
    patterns = [
        ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())),
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    # Group the matched ops into composite functions, merge them into
    # "Codegen"-annotated functions, and hand those to the external codegen.
    new_mod = tvm.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )(mod)

    ex0 = relax.build(new_mod, target, params={})

    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected, entry_func_name)
def test_mix_use_tensorrt_and_tvm():
# Builds the shared model with only `add` offloaded to TensorRT, then checks the
# result and the export/load round trip via check_roundtrip (default entry "main").
# NOTE(review): this function arrived damaged -- indentation is gone and several
# lines or fragments are missing (flagged below). Restore it from version control
# before relying on it; the notes only record what is visibly absent.
# NOTE(review): `requires_tensorrt_runtime` is defined in this file but not applied
# here even though the test executes the built module -- a decorator line may have
# been lost; confirm.
mod, inputs, expected = setup_test()
# Define patterns that we want to offload to byoc
# This test will only offload `add` op to tensorrt
# and tune `multiply` op with MetaSchedule
patterns = [
("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
# NOTE(review): the closing `]` of `patterns` is missing here.
# Run Codegen pass
with tempfile.TemporaryDirectory() as work_dir:
with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
new_mod = tvm.transform.Sequential(
# NOTE(review): the pass list is missing. Only the keyword arguments on the next
# line survive; presumably they belong to a MetaSchedule tuning pass writing into
# `work_dir` -- confirm against the original file.
params={}, work_dir=work_dir, max_trials_global=8
# NOTE(review): the rest of the Sequential(...) call and its application to `mod`
# are missing here.
assert relax.analysis.well_formed(new_mod)
with transform.PassContext(opt_level=0):
# NOTE(review): the right-hand side below is truncated; the surviving arguments
# `target, params={}` suggest a build call on `new_mod` -- confirm.
ex0 =, target, params={})
# Sanity check for the correctness and roundtrip
check_roundtrip(ex0, dev, inputs, expected)
# Input for the RunCodegen tests: `main` calls the same TensorRT-annotated
# function twice, once per weight tensor.
@I.ir_module
class Conv2dx2:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        cls = Conv2dx2
        with R.dataflow():
            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt(
                data, weight1
            )
            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt(
                lv, weight2
            )
            R.output(gv)
        return gv

    @R.function
    def fused_relax_nn_conv2d_tensorrt(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        R.func_attr({"Codegen": "tensorrt", "global_symbol": "fused_relax_nn_conv2d_tensorrt"})

        @R.function
        def gv(
            data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
            weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
            R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
            with R.dataflow():
                # NHWC data with an OHWI kernel is the layout pair consistent
                # with the annotated shapes: 16 input channels, a 3x3 kernel,
                # and padding 1 keeping the 32x32 spatial size.
                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
                    data_1,
                    weight1_1,
                    padding=[1, 1, 1, 1],
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                )
                R.output(gv_1)
            return gv_1

        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
        return gv1
# Expected form of `Conv2dx2` after RunCodegen: each call to the annotated
# function becomes a destination-passing-style call to the external symbol.
@I.ir_module
class Conv2dx2_after:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        with R.dataflow():
            lv = R.call_dps_packed(
                "fused_relax_nn_conv2d_tensorrt",
                (data, weight1),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            gv = R.call_dps_packed(
                "fused_relax_nn_conv2d_tensorrt",
                (lv, weight2),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            R.output(gv)
        return gv
def test_multiple_calls_same_extern():
    """RunCodegen rewrites every call when one extern function is called twice."""
    mod = relax.transform.RunCodegen()(Conv2dx2)
    tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
def test_default_entry_func():
    """The entry function is not necessarily named "main"

    Like `test_multiple_calls_same_extern`, but the main function is
    named "func".
    """
    before_with_main = Conv2dx2
    after_with_main = relax.transform.RunCodegen()(before_with_main)

    def rename_main(mod):
        """Return a copy of ``mod`` whose ``main`` function is renamed to ``func``."""
        mod = mod.clone()
        mod["func"] = mod["main"].with_attr("global_symbol", "func")
        del mod["main"]
        return mod

    before_with_func = rename_main(before_with_main)
    # Renaming after the pass gives the reference for running the pass after renaming.
    expected_with_func = rename_main(after_with_main)
    after_with_func = relax.transform.RunCodegen()(before_with_func)

    tvm.ir.assert_structural_equal(expected_with_func["func"], after_with_func["func"])
def test_dynamic_shape():
    """RunCodegen rewrites calls to a cuBLAS-annotated function with symbolic shapes.

    The same annotated matmul is called with two weights whose second
    dimensions are the distinct symbolic sizes ``r1`` and ``r2``; both calls
    must become ``call_dps_packed`` with the matching output shape.
    """
    # Imported for its side effects only; presumably registers the cuBLAS
    # BYOC support that RunCodegen looks up -- TODO confirm.
    import tvm.relax.backend.contrib.cublas

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor((1, 4096), dtype="float16"),
            w1: R.Tensor((4096, "r1"), dtype="float16"),
            w2: R.Tensor((4096, "r2"), dtype="float16"),
        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
            r1 = T.int64()
            r2 = T.int64()
            cls = Before
            with R.dataflow():
                lv: R.Tensor((1, r1), dtype="float16") = cls.fused_relax_matmul_cublas(x, w1)
                lv1: R.Tensor((1, r2), dtype="float16") = cls.fused_relax_matmul_cublas(x, w2)
                gv: R.Tuple(
                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
                ) = (lv, lv1)
                R.output(gv)
            return gv

        @R.function
        def fused_relax_matmul_cublas(
            x: R.Tensor((1, 4096), dtype="float16"), w1: R.Tensor((4096, "r1"), dtype="float16")
        ) -> R.Tensor((1, "r1"), dtype="float16"):
            r1 = T.int64()
            R.func_attr({"Codegen": "cublas"})

            @R.function
            def gv(
                x_1: R.Tensor((1, 4096), dtype="float16"),
                w1_1: R.Tensor((4096, r1), dtype="float16"),
            ) -> R.Tensor((1, r1), dtype="float16"):
                R.func_attr({"Composite": "cublas.matmul"})
                with R.dataflow():
                    gv_1: R.Tensor((1, r1), dtype="float16") = R.matmul(x_1, w1_1, out_dtype="void")
                    R.output(gv_1)
                return gv_1

            gv1: R.Tensor((1, r1), dtype="float16") = gv(x, w1)
            return gv1

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, 4096), dtype="float16"),
            w1: R.Tensor((4096, "r1"), dtype="float16"),
            w2: R.Tensor((4096, "r2"), dtype="float16"),
        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
            r1 = T.int64()
            r2 = T.int64()
            with R.dataflow():
                lv = R.call_dps_packed(
                    "fused_relax_matmul_cublas",
                    (x, w1),
                    out_sinfo=R.Tensor((1, r1), dtype="float16"),
                )
                lv1 = R.call_dps_packed(
                    "fused_relax_matmul_cublas",
                    (x, w2),
                    out_sinfo=R.Tensor((1, r2), dtype="float16"),
                )
                gv: R.Tuple(
                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
                ) = (lv, lv1)
                R.output(gv)
            return gv

    after = relax.transform.RunCodegen()(Before)
    tvm.ir.assert_structural_equal(after["main"], Expected["main"])
def test_no_op_for_call_to_tir():
    """Calls to PrimFunc are ignored

    RunCodegen should only update calls to Relax functions annotated
    with the `"Codegen"` attribute. Calls to any other function type
    should be ignored.

    This is a regression test. Previous implementations performed an
    unconditional cast from `tvm::BaseFunc` to `tvm::relax::Function`,
    which produced an error.
    """

    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor([4], "int64")):
            R.func_attr({"relax.force_pure": True})
            _ = Before.shape_func(x)
            return x

        @T.prim_func
        def shape_func(H: T.Buffer(T.int64(4), "int64")):
            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

    # No function carries a "Codegen" attribute, so the pass must be a no-op.
    Expected = Before
    After = relax.transform.RunCodegen()(Before)
    tvm.ir.assert_structural_equal(Expected, After)
# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
if __name__ == "__main__":