[CudaGraph] Handle exceptions thrown while capturing cuda graph (#17113)

* [CudaGraph] Handle exceptions thrown while capturing cuda graph

Prior to this commit, an exception thrown during the capture of a cuda
graph would result in `std::terminate` being called.  This commit
updates the implementation of `"vm.builtin.cuda_graph.run_or_capture"`
such that a thrown exception can be recovered from, and does not cause
any changes to the state of TVM's cuda graph cache.

- Call to `cudaStreamDestroy` was previously skipped, now moved to a
  RAII-style destructor in a `ScopedCUDAStream` class.

- Call to `cudaStreamEndCapture` was previously skipped, end of cuda
  graph capture now performed as part of RAII-style destructor for
  `CUDACaptureStream` class.

- Restoration of `CUDAThreadEntry::ThreadLocal()->stream` was
  previously skipped, now restored as part of RAII-style destructor
  for `CUDACaptureStream` class.

- Previously, an error raised from `cudaGraphInstantiate` would leave
  the `capture_cache_` in an ill-formed state.  Now, the
  `capture_cache_` is only updated after a valid
  `CUDAGraphCapturedState` has been fully constructed.

* lint fix

* Unit test fix
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index dea497e..e8901c0 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -32,6 +32,8 @@
 namespace runtime {
 namespace relax_vm {
 
+namespace {
+
 struct CUDAGraphCaptureKey {
   // The unique index of the capture function within the module
   int64_t index;
@@ -67,6 +69,18 @@
 
 /*! \brief The captured state of a CUDA graph */
 struct CUDAGraphCapturedState {
+  CUDAGraphCapturedState() {}
+
+  CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
+  CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); }
+
+  CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
+  CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) {
+    std::swap(states, other.states);
+    std::swap(exec, other.exec);
+    return *this;
+  }
+
   ~CUDAGraphCapturedState() {
     if (exec) {
       CUDA_CALL(cudaGraphExecDestroy(exec));
@@ -82,6 +96,43 @@
   cudaGraphExec_t exec = nullptr;
 };
 
+class ScopedCUDAStream {
+ public:
+  ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
+  ~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
+  ScopedCUDAStream(const ScopedCUDAStream&) = delete;
+  ScopedCUDAStream(ScopedCUDAStream&&) = delete;
+  ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
+  ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;
+
+  operator cudaStream_t() const { return stream_; }
+
+ private:
+  cudaStream_t stream_;
+};
+
+class CUDACaptureStream {
+ public:
+  explicit CUDACaptureStream(cudaGraph_t* graph)
+      : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
+    CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;
+
+    CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
+  }
+  ~CUDACaptureStream() {
+    cudaStreamEndCapture(capture_stream_, output_graph_);
+    CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
+  }
+
+ private:
+  cudaStream_t prev_default_stream_;
+  ScopedCUDAStream capture_stream_;
+
+  cudaGraph_t* output_graph_;
+};
+
+}  // namespace
+
 /*! \brief The VM extension of CUDA graph. */
 class CUDAGraphExtensionNode : public VMExtensionNode {
  public:
@@ -107,10 +158,6 @@
       return states;
     }
 
-    cudaStream_t capture_stream;
-    CUDA_CALL(cudaStreamCreate(&capture_stream));
-    CUDAGraphCapturedState entry;
-
     // Set up arguments for the graph execution
     Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
     int nargs = static_cast<int>(tuple_args.size());
@@ -130,21 +177,23 @@
 
     // Run the graph in capture mode
     cudaGraph_t graph;
-    std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
-    CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream,
-                                     cudaStreamCaptureModeGlobal));
 
-    vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
-                            &capture_func_rv);
+    {
+      CUDACaptureStream capture_stream(&graph);
+      vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
+                              &capture_func_rv);
+    }
+
+    CUDAGraphCapturedState entry;
     entry.states = capture_func_rv;
-    CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
-    std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
-
-    capture_cache_[entry_key] = entry;
-    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
-    CUDA_CALL(cudaStreamDestroy(capture_stream));
+    CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
     CUDA_CALL(cudaGraphDestroy(graph));
-    return entry.states;
+
+    ObjectRef states = entry.states;
+
+    capture_cache_[entry_key] = std::move(entry);
+
+    return states;
   }
 
   /*!
diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py
index 6a20b6b..49ebcc1 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -16,10 +16,13 @@
 # under the License.
 
 import tvm
-from tvm.script import tir as T, relax as R, ir as I
-from tvm import relax
 import tvm.testing
+
+from tvm import relax
+from tvm.script import tir as T, relax as R, ir as I
+
 import numpy as np
+import pytest
 
 
 # fmt: off
@@ -104,5 +107,75 @@
     tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.requires_cudagraph
+def test_capture_error_is_recoverable():
+    """Function calls while capturing cudagraph may throw exceptions
+
+    Calls to PackedFuncs may occur within a captured cudaGraph.  If a
+    call to that PackedFunc raises an exception while capturing the
+    cudaGraph, throwing exception should cleanly unwind the stack, and
+    the exception may be caught in the calling scope.
+
+    This is a regression test.  In previous implementations, an
+    exception thrown while capturing a cudaGraph would skip the call
+    to `cudaStreamEndCapture`, causing additional exceptions to be
+    thrown while freeing memory in TVM destructors.  Since C++ does
+    not support stack unwinding from multiple simultaneous exceptions,
+    this would result in immediate `std::terminate`, making it
+    difficult to debug the original error.
+
+    """
+
+    target = tvm.target.Target("cuda")
+    dev = tvm.cuda()
+
+    @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
+    def invalid_impl_for_cudagraph(arg_tensor):
+        # Memory allocation/deallocation may not be performed while
+        # capturing a cudaGraph.  This passes the warm-up run
+        # performed by "vm.builtin.cuda_graph.run_or_capture", but
+        # throws an exception when the cudaGraph is being captured.
+        _dummy_workspace = tvm.nd.empty([16], "float16", dev)
+        return arg_tensor
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16], "float16")):
+            B = R.add(A, A)
+            C = R.call_pure_packed(
+                "test_vm_cuda_graph.invalid_impl_for_cudagraph",
+                B,
+                sinfo_args=R.Tensor([16], "float16"),
+            )
+            D = R.add(C, C)
+            return D
+
+    with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        Module = tvm.ir.transform.Sequential(
+            [
+                tvm.relax.transform.LegalizeOps(),
+                tvm.tir.transform.DefaultGPUSchedule(),
+                tvm.relax.transform.RemovePurityChecking(),
+                tvm.relax.transform.CallTIRRewrite(),
+                tvm.relax.transform.StaticPlanBlockMemory(),
+                tvm.relax.transform.RewriteCUDAGraph(),
+            ]
+        )(Module)
+
+    assert "cuda_graph_alloc" in Module, (
+        "Validity of unit test requires the call to `invalid_impl_for_cudagraph` "
+        "to have been captured by RewriteCUDAGraph."
+    )
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    arg = tvm.nd.array(np.arange(16).astype("float16"), dev)
+
+    with pytest.raises(tvm.TVMError):
+        vm["main"](arg)
+
+
 if __name__ == "__main__":
     tvm.testing.main()