[CudaGraph] Handle exceptions thrown while capturing CUDA graph (#17113)
* [CudaGraph] Handle exceptions thrown while capturing CUDA graph
Prior to this commit, an exception thrown during the capture of a CUDA
graph would result in `std::terminate` being called. This commit
updates the implementation of `"vm.builtin.cuda_graph.run_or_capture"`
such that a thrown exception can be recovered from, and does not cause
any changes to the state of TVM's CUDA graph cache.
- Call to `cudaStreamDestroy` was previously skipped, now moved to a
RAII-style destructor in a `ScopedCUDAStream` class.
- Call to `cudaStreamEndCapture` was previously skipped, end of CUDA
graph capture now performed as part of RAII-style destructor for
`CUDACaptureStream` class.
- Restoration of `CUDAThreadEntry::ThreadLocal()->stream` was
previously skipped, now restored as part of RAII-style destructor
for `CUDACaptureStream` class.
- Previously, an error raised from `cudaGraphInstantiate` would leave
the `capture_cache_` in an ill-formed state. Now, the
`capture_cache_` is only updated after a valid
`CUDAGraphCapturedState` has been fully constructed.
* lint fix
* Unit test fix
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index dea497e..e8901c0 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -32,6 +32,8 @@
namespace runtime {
namespace relax_vm {
+namespace {
+
struct CUDAGraphCaptureKey {
// The unique index of the capture function within the module
int64_t index;
@@ -67,6 +69,18 @@
/*! \brief The captured state of a CUDA graph */
struct CUDAGraphCapturedState {
+ CUDAGraphCapturedState() {}
+
+ CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
+ CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); }
+
+ CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
+ CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) {
+ std::swap(states, other.states);
+ std::swap(exec, other.exec);
+ return *this;
+ }
+
~CUDAGraphCapturedState() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
@@ -82,6 +96,43 @@
cudaGraphExec_t exec = nullptr;
};
+class ScopedCUDAStream {
+ public:
+ ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
+ ~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
+ ScopedCUDAStream(const ScopedCUDAStream&) = delete;
+ ScopedCUDAStream(ScopedCUDAStream&&) = delete;
+ ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
+ ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;
+
+ operator cudaStream_t() const { return stream_; }
+
+ private:
+ cudaStream_t stream_;
+};
+
+class CUDACaptureStream {
+ public:
+ explicit CUDACaptureStream(cudaGraph_t* graph)
+ : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
+ CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;
+
+ CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
+ }
+ ~CUDACaptureStream() {
+ cudaStreamEndCapture(capture_stream_, output_graph_);
+ CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
+ }
+
+ private:
+ cudaStream_t prev_default_stream_;
+ ScopedCUDAStream capture_stream_;
+
+ cudaGraph_t* output_graph_;
+};
+
+} // namespace
+
/*! \brief The VM extension of CUDA graph. */
class CUDAGraphExtensionNode : public VMExtensionNode {
public:
@@ -107,10 +158,6 @@
return states;
}
- cudaStream_t capture_stream;
- CUDA_CALL(cudaStreamCreate(&capture_stream));
- CUDAGraphCapturedState entry;
-
// Set up arguments for the graph execution
Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
int nargs = static_cast<int>(tuple_args.size());
@@ -130,21 +177,23 @@
// Run the graph in capture mode
cudaGraph_t graph;
- std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
- CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream,
- cudaStreamCaptureModeGlobal));
- vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
- &capture_func_rv);
+ {
+ CUDACaptureStream capture_stream(&graph);
+ vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
+ &capture_func_rv);
+ }
+
+ CUDAGraphCapturedState entry;
entry.states = capture_func_rv;
- CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
- std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
-
- capture_cache_[entry_key] = entry;
- CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
- CUDA_CALL(cudaStreamDestroy(capture_stream));
+ CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
CUDA_CALL(cudaGraphDestroy(graph));
- return entry.states;
+
+ ObjectRef states = entry.states;
+
+ capture_cache_[entry_key] = std::move(entry);
+
+ return states;
}
/*!
diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py
index 6a20b6b..49ebcc1 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -16,10 +16,13 @@
# under the License.
import tvm
-from tvm.script import tir as T, relax as R, ir as I
-from tvm import relax
import tvm.testing
+
+from tvm import relax
+from tvm.script import tir as T, relax as R, ir as I
+
import numpy as np
+import pytest
# fmt: off
@@ -104,5 +107,75 @@
tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5)
+@tvm.testing.requires_cudagraph
+def test_capture_error_is_recoverable():
+ """Function calls while capturing cudagraph may throw exceptions
+
+ Calls to PackedFuncs may occur within a captured cudaGraph. If a
+ call to that PackedFunc raises an exception while capturing the
+ cudaGraph, the thrown exception should cleanly unwind the stack, and
+ the exception may be caught in the calling scope.
+
+ This is a regression test. In previous implementations, an
+ exception thrown while capturing a cudaGraph would skip the call
+ to `cudaStreamEndCapture`, causing additional exceptions to be
+ thrown while freeing memory in TVM destructors. Since C++ does
+ not support stack unwinding from multiple simultaneous exceptions,
+ this would result in immediate `std::terminate`, making it
+ difficult to debug the original error.
+
+ """
+
+ target = tvm.target.Target("cuda")
+ dev = tvm.cuda()
+
+ @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
+ def invalid_impl_for_cudagraph(arg_tensor):
+ # Memory allocation/deallocation may not be performed while
+ # capturing a cudaGraph. This passes the warm-up run
+ # performed by "vm.builtin.cuda_graph.run_or_capture", but
+ # throws an exception when the cudaGraph is being captured.
+ _dummy_workspace = tvm.nd.empty([16], "float16", dev)
+ return arg_tensor
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.add(A, A)
+ C = R.call_pure_packed(
+ "test_vm_cuda_graph.invalid_impl_for_cudagraph",
+ B,
+ sinfo_args=R.Tensor([16], "float16"),
+ )
+ D = R.add(C, C)
+ return D
+
+ with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+ Module = tvm.ir.transform.Sequential(
+ [
+ tvm.relax.transform.LegalizeOps(),
+ tvm.tir.transform.DefaultGPUSchedule(),
+ tvm.relax.transform.RemovePurityChecking(),
+ tvm.relax.transform.CallTIRRewrite(),
+ tvm.relax.transform.StaticPlanBlockMemory(),
+ tvm.relax.transform.RewriteCUDAGraph(),
+ ]
+ )(Module)
+
+ assert "cuda_graph_alloc" in Module, (
+ "Validity of unit test requires the call to `invalid_impl_for_cudagraph` "
+ "to have been captured by RewriteCUDAGraph."
+ )
+
+ built = tvm.relax.build(Module, target=target)
+ vm = tvm.relax.VirtualMachine(built, dev)
+
+ arg = tvm.nd.array(np.arange(16).astype("float16"), dev)
+
+ with pytest.raises(tvm.TVMError):
+ vm["main"](arg)
+
+
if __name__ == "__main__":
tvm.testing.main()