# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing
from tvm import relax
from tvm.script import tir as T, relax as R, ir as I

import numpy as np
import pytest


# fmt: off
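# Hand-written module in the form produced by the RewriteCUDAGraph pass:
# `main` fetches cached storage through "vm.builtin.cuda_graph.get_cached_alloc"
# and dispatches the extracted subgraph through
# "vm.builtin.cuda_graph.run_or_capture", which runs a warm-up pass and
# captures a cudaGraph on the first call, then replays it on later calls.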
@I.ir_module
class Module:
    @R.function(pure=False)
    def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
        cls = Module
        R.func_attr({"global_symbol": "main"})
        gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),))
        storage: R.Object = gv[0]
        alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
        _: R.Tuple = cls.add(x, alloc)
        storage1: R.Object = gv[1]
        gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage)
        gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),))
        storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8"))
        alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
        lv4: R.Tensor((16, 16), dtype="float32") = gv2[0]
        _3: R.Tuple = cls.add(lv4, alloc3)
        lv5: R.Tensor(dtype="float32") = alloc3
        return lv5
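
    # Elementwise B = A + 1 kernel; `main` and `cuda_graph_capture` chain four
    # calls in total, which is why test_vm_run expects x + 4.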
    @T.prim_func
    def add(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        T.func_attr({"global_symbol": "add"})
        with T.block("root"):
            for i in T.thread_binding(16, thread="threadIdx.x"):
                for j in range(16):
                    with T.block("update"):
                        vi, vj = T.axis.remap("SS", [i, j])
                        B[vi, vj] = A[vi, vj] + T.float32(1)
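
    # Storage allocations are hoisted out of the captured region; the VM caches
    # this function's result via "vm.builtin.cuda_graph.get_cached_alloc".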
    @R.function
    def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
        R.func_attr({"global_symbol": "cuda_graph_alloc"})
        storage: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8"))
        storage1: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8"))
        gv: R.Tuple(R.Object, R.Object) = (storage, storage1)
        return gv
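
    # The region handed to "vm.builtin.cuda_graph.run_or_capture": captured into
    # a cudaGraph on the first invocation and replayed afterwards.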
    @R.function(pure=False)
    def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.Object, storage: R.Object) -> R.Tuple(R.Tensor((16, 16), dtype="float32")):
        cls = Module
        R.func_attr({"global_symbol": "cuda_graph_capture"})
        lv0: R.Tensor((16, 16), dtype="float32") = alloc
        alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
        _1: R.Tuple = cls.add(lv0, alloc1)
        lv1: R.Tensor(dtype="float32") = alloc1
        lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,)
        lv3: R.Tensor(dtype="float32") = lv2[0]
        alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
        _2: R.Tuple = cls.add(lv3, alloc2)
        lv4: R.Tensor(dtype="float32") = alloc2
        gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,)
        return gv
# fmt: on
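

# Drive the internal vm_build helpers directly: generate VM bytecode into the
# ExecBuilder, keep the remaining TIR PrimFuncs, then compile and link them.
# This bypasses tvm.compile so the test can control exec_mode.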
def codegen(mod, target, exec_mode="bytecode"):
    builder = relax.ExecBuilder()
    leftover_mod = relax.vm_build._vmcodegen(builder, mod, exec_mode=exec_mode)
    tir_mod = relax.vm_build._filter_tir(leftover_mod)
    return relax.vm_build._vmlink(builder, target, tir_mod)


@tvm.testing.requires_cuda
def test_vm_run():
    mod = Module
    target = tvm.target.Target("cuda", host="llvm")
    ex = codegen(mod, target)
    dev = tvm.cuda(0)
    vm = relax.VirtualMachine(ex, dev)

    x_np = np.random.uniform(size=(16, 16)).astype("float32")
    x = tvm.runtime.tensor(x_np, dev)
    y = vm["main"](x)
    y_np = x_np + 1.0 + 1.0 + 1.0 + 1.0
    tvm.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5)


@tvm.testing.requires_cudagraph
def test_capture_error_is_recoverable():
    """Function calls made while capturing a cudaGraph may throw exceptions

    Calls to PackedFuncs may occur within a captured cudaGraph. If a
    call to such a PackedFunc raises an exception while the cudaGraph
    is being captured, throwing the exception should cleanly unwind the
    stack, and the exception may be caught in the calling scope.

    This is a regression test. In previous implementations, an
    exception thrown while capturing a cudaGraph would skip the call
    to `cudaStreamEndCapture`, causing additional exceptions to be
    thrown while freeing memory in TVM destructors. Since C++ does
    not support stack unwinding from multiple simultaneous exceptions,
    this would result in an immediate `std::terminate`, making it
    difficult to debug the original error.
    """
target = tvm.target.Target("cuda")
dev = tvm.cuda()
@tvm.register_global_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
def invalid_impl_for_cudagraph(arg_tensor):
# Memory allocation/deallocation may not be performed while
# capturing a cudaGraph. This passes the warm-up run
# performed by "vm.builtin.cuda_graph.run_or_capture", but
# throws an exception when the cudaGraph is being captured.
_dummy_workspace = tvm.runtime.empty([16], "float16", dev)
return arg_tensor

    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Tensor([16], "float16")):
            B = R.add(A, A)
            C = R.call_pure_packed(
                "test_vm_cuda_graph.invalid_impl_for_cudagraph",
                B,
                sinfo_args=R.Tensor([16], "float16"),
            )
            D = R.add(C, C)
            return D
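
    # Manually run the lowering passes required before RewriteCUDAGraph can
    # extract a capturable region from the module.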
    with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
        Module = tvm.ir.transform.Sequential(
            [
                tvm.relax.transform.LegalizeOps(),
                tvm.tir.transform.DefaultGPUSchedule(),
                tvm.relax.transform.RemovePurityChecking(),
                tvm.relax.transform.CallTIRRewrite(),
                tvm.relax.transform.StaticPlanBlockMemory(),
                tvm.relax.transform.RewriteCUDAGraph(),
            ]
        )(Module)

    assert "cuda_graph_alloc" in Module, (
        "Validity of this unit test requires the call to `invalid_impl_for_cudagraph` "
        "to have been captured by RewriteCUDAGraph."
    )
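
    # The first call performs a warm-up run and then captures the cudaGraph;
    # the injected allocation makes the capture step fail, and the error must
    # surface as a catchable TVMError rather than a std::terminate.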
    built = tvm.compile(Module, target=target)
    vm = tvm.relax.VirtualMachine(built, dev)

    arg = tvm.runtime.tensor(np.arange(16).astype("float16"), dev)
    with pytest.raises(tvm.TVMError):
        vm["main"](arg)


if __name__ == "__main__":
    tvm.testing.main()