blob: f9438bb487e6d6da44c6b274eb0162bdc6340810 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/vm/cuda_graph_builtin.cc
* \brief The CUDA graph related builtin functions for Relax virtual machine.
*/
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/vm/vm.h>
#include "../../../support/utils.h"
#include "../../cuda/cuda_common.h"
namespace tvm {
namespace runtime {
namespace vm {
namespace {
/*! \brief Key used to look up a captured CUDA graph in the capture cache. */
struct CUDAGraphCaptureKey {
  /*! \brief The unique index of the capture function within the module. */
  int64_t index;
  /*!
   * \brief The symbolic variables the capture function depends on.
   *
   * When the capture function is run with different symbolic variable values, the CUDA graph is
   * re-captured as a different version, identified by this shape tuple. It is default
   * constructed as an empty tuple and stays that way when no shape is supplied.
   */
  ffi::Shape shape_expr;

  /*!
   * \brief Build a key from the capture function index and an optional shape tuple.
   * \param index The unique index of the capture function.
   * \param shape_expr The shape tuple; when absent the member keeps its default value.
   */
  CUDAGraphCaptureKey(int64_t index, const ffi::Optional<ffi::Shape>& shape_expr) : index(index) {
    if (!shape_expr) {
      return;
    }
    this->shape_expr = shape_expr.value();
  }
};
/*!
 * \brief Hash functor for CUDAGraphCaptureKey.
 *
 * Starts from the hash of the capture function index and mixes in every element of the shape
 * tuple, so that different shape versions of the same capture function get different hashes.
 */
struct CUDAGraphCaptureKeyHash {
  size_t operator()(const CUDAGraphCaptureKey& key) const {
    std::hash<int64_t> hash_fn;
    size_t hash = hash_fn(key.index);
    for (const auto& dim : key.shape_expr) {
      // Store the combined value back into `hash`. Discarding the result of HashCombine would
      // leave the hash depending on `index` only, making all shape versions collide.
      // NOTE(review): relies on support::HashCombine returning the combined hash — confirm
      // against support/utils.h.
      hash = support::HashCombine(hash, hash_fn(dim));
    }
    return hash;
  }
};
struct CUDAGraphCaptureKeyEqual {
bool operator()(const CUDAGraphCaptureKey& lhs, const CUDAGraphCaptureKey& rhs) const {
return lhs.index == rhs.index && std::equal(lhs.shape_expr.begin(), lhs.shape_expr.end(),
rhs.shape_expr.begin(), rhs.shape_expr.end());
}
};
/*!
 * \brief The captured state of a CUDA graph.
 *
 * Move-only: the object owns `exec` and destroys it in the destructor, so copying is disabled
 * and moving transfers the handle by swapping.
 */
struct CUDAGraphCapturedState {
  CUDAGraphCapturedState() = default;
  CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
  CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
  // The members start from their default values (`exec == nullptr`), so delegating to the
  // swapping move assignment leaves `other` in the default state.
  CUDAGraphCapturedState(CUDAGraphCapturedState&& other) noexcept { *this = std::move(other); }
  // Swap instead of overwrite: whatever this object held before ends up in `other` and is
  // released when `other` is destroyed.
  CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) noexcept {
    std::swap(states, other.states);
    std::swap(exec, other.exec);
    return *this;
  }
  ~CUDAGraphCapturedState() {
    // A default-constructed or moved-from state holds no graph and has nothing to destroy.
    if (exec) {
      CUDA_CALL(cudaGraphExecDestroy(exec));
    }
  }
  /*!
   * \brief Tuple of intermediate tensors in the capture func that will be used outside the
   * capture func
   */
  ffi::ObjectRef states;
  /*! \brief The instantiated cuda graph */
  cudaGraphExec_t exec = nullptr;
};
/*! \brief Holder of a CUDA stream that is created on construction and destroyed on destruction. */
class ScopedCUDAStream {
 public:
  // Neither copyable nor movable, so the stream handle is destroyed exactly once.
  ScopedCUDAStream(const ScopedCUDAStream&) = delete;
  ScopedCUDAStream(ScopedCUDAStream&&) = delete;
  ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
  ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;

  /*! \brief Create the stream; the creation status is checked with CUDA_CALL. */
  ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
  /*! \brief Destroy the stream; the status returned by cudaStreamDestroy is not checked. */
  ~ScopedCUDAStream() { cudaStreamDestroy(stream_); }

  /*! \brief Implicit conversion to the underlying stream handle. */
  operator cudaStream_t() const { return stream_; }

 private:
  /*! \brief The stream created in the constructor. */
  cudaStream_t stream_;
};
/*!
 * \brief Scope guard that runs CUDA stream capture for the lifetime of the object.
 *
 * Construction installs a freshly created stream through TVMFFIEnvSetStream for the active
 * device and begins capture on it; destruction ends the capture, writing the graph to the
 * pointer given at construction, and re-installs the previously saved stream.
 */
class CUDACaptureStream {
public:
/*!
 * \brief Begin capturing on a new stream.
 * \param graph Location that receives the captured graph when this object is destroyed.
 */
explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
// The device queried here is reused in the destructor when restoring the stream.
CUDA_CALL(cudaGetDevice(&device_id_));
// Install capture_stream_ and pass &prev_default_stream_ so the stream that was set before can
// be put back in the destructor.
TVM_FFI_CHECK_SAFE_CALL(
TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
}
// Declared noexcept(false) — presumably because TVM_FFI_CHECK_SAFE_CALL can raise an error.
~CUDACaptureStream() noexcept(false) {
// The status returned by cudaStreamEndCapture is not checked here.
cudaStreamEndCapture(capture_stream_, output_graph_);
// Re-install the stream saved by the constructor.
TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
}
private:
/*! \brief The device returned by cudaGetDevice at construction. */
int device_id_;
/*! \brief The stream saved at construction and restored at destruction. */
cudaStream_t prev_default_stream_;
/*! \brief The stream on which capture runs; it is default-constructed before the constructor
 * body executes. */
ScopedCUDAStream capture_stream_;
/*! \brief Destination of the captured graph; not owned by this object. */
cudaGraph_t* output_graph_;
};
} // namespace
/*! \brief The VM extension of CUDA graph. */
class CUDAGraphExtensionNode : public VMExtensionNode {
 public:
  /*!
   * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
   * \param vm The virtual machine.
   * \param capture_func The function of type (args...) -> Tuple[ffi::ObjectRef], where 'args' are
   * the static arguments that are the same for all invocations of the capture function, the
   * returned tuple contains the intermediate tensors that will be used outside the capture
   * function.
   * \param args The static arguments of the capture function
   * \param entry_index The unique index of the capture function used for lookup.
   * \param shape_expr The optional shape tuple that, together with entry_index, selects the
   * cached version of the graph.
   * \return The return value of the capture function.
   */
  ffi::ObjectRef RunOrCapture(VirtualMachine* vm, const ffi::ObjectRef& capture_func, Any args,
                              int64_t entry_index, ffi::Optional<ffi::Shape> shape_expr) {
    CUDAGraphCaptureKey entry_key{entry_index, shape_expr};
    if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) {
      // Cache hit: launch the instantiated graph on the stream reported by TVMFFIEnvGetStream
      // for the active device and hand back the states recorded at capture time.
      const CUDAGraphCapturedState& cached = it->second;
      int device_id;
      CUDA_CALL(cudaGetDevice(&device_id));
      CUDA_CALL(cudaGraphLaunch(
          cached.exec, static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, device_id))));
      return cached.states;
    }

    // Set up arguments for the graph execution
    ffi::Array<Any> tuple_args = args.cast<ffi::Array<Any>>();
    int nargs = static_cast<int>(tuple_args.size());
    std::vector<AnyView> packed_args(nargs);
    for (int i = 0; i < nargs; ++i) {
      packed_args[i] = tuple_args[i];
    }

    ffi::Any capture_func_rv;
    // Run the function without CUDA graph. This is a warm up step to do necessary initialization
    // of the CUDA module such as loading module data, setting kernel attributes.
    vm->InvokeClosurePacked(capture_func, ffi::PackedArgs(packed_args.data(), nargs),
                            &capture_func_rv);

    // Run the graph in capture mode. The handle starts out null because CUDACaptureStream does
    // not check the status of cudaStreamEndCapture: if ending the capture fails without writing a
    // graph, instantiation below sees a well-defined null handle rather than an uninitialized one.
    cudaGraph_t graph = nullptr;
    {
      CUDACaptureStream capture_stream(&graph);
      vm->InvokeClosurePacked(capture_func, ffi::PackedArgs(packed_args.data(), nargs),
                              &capture_func_rv);
    }

    CUDAGraphCapturedState entry;
    entry.states = capture_func_rv.cast<ffi::ObjectRef>();
    // Destroy the captured graph before checking the instantiation status, so that a failed
    // instantiation does not leak it. The instantiation status is reported first.
    const auto instantiate_status =
        cudaGraphInstantiate(&entry.exec, graph, nullptr, nullptr, 0);
    const auto destroy_status = cudaGraphDestroy(graph);
    CUDA_CALL(instantiate_status);
    CUDA_CALL(destroy_status);

    // Keep a handle to the states, since `entry` is moved into the cache.
    ffi::ObjectRef states = entry.states;
    capture_cache_[entry_key] = std::move(entry);
    return states;
  }

  /*!
   * \brief Get the cached allocation from the cache or run the allocation function.
   * \param vm The virtual machine.
   * \param alloc_func The function of type () -> ffi::ObjectRef, where the returned object is the
   * tuple of allocated storage objects.
   * \param entry_index The unique index of the allocation function used for lookup.
   * \return The cached allocation if present, otherwise the newly computed (and now cached)
   * result of alloc_func.
   */
  ffi::ObjectRef GetCachedAllocation(VirtualMachine* vm, const ffi::ObjectRef& alloc_func,
                                     int64_t entry_index) {
    if (auto it = alloc_cache_.find(entry_index); it != alloc_cache_.end()) {
      return it->second;
    }
    // Cache miss: invoke the allocation function with no arguments and remember its result.
    ffi::Any alloc_func_rv;
    vm->InvokeClosurePacked(alloc_func, ffi::PackedArgs(nullptr, 0), &alloc_func_rv);
    ffi::ObjectRef alloc_result = alloc_func_rv.cast<ffi::ObjectRef>();
    alloc_cache_[entry_index] = alloc_result;
    return alloc_result;
  }

  static constexpr const bool _type_mutable = true;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("vm.CUDAGraphExtension", CUDAGraphExtensionNode,
                                    VMExtensionNode);

 private:
  /*!
   * \brief The cache of captured cuda graphs. The key combines the unique index of the capture
   * function with the shape tuple it was captured for. The value is the result of the capture.
   */
  std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState, CUDAGraphCaptureKeyHash,
                     CUDAGraphCaptureKeyEqual>
      capture_cache_;
  /*!
   * \brief The cache of allocations. The key is a unique index for the allocation function.
   * The value is the cached allocations, which is a tuple of storages.
   */
  std::unordered_map<int64_t, ffi::ObjectRef> alloc_cache_;
};
/*! \brief Managed reference to CUDAGraphExtensionNode */
class CUDAGraphExtension : public VMExtension {
 public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CUDAGraphExtension, VMExtension,
                                             CUDAGraphExtensionNode);
  /*! \brief Create an extension backed by a newly allocated CUDAGraphExtensionNode. */
  static CUDAGraphExtension Create() {
    auto node = ffi::make_object<CUDAGraphExtensionNode>();
    return CUDAGraphExtension(std::move(node));
  }
};
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("vm.builtin.cuda_graph.run_or_capture",
[](ffi::PackedArgs args, ffi::Any* rv) {
TVM_FFI_ICHECK(args.size() == 5 || args.size() == 4);
VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
auto capture_func = args[1].cast<ffi::ObjectRef>();
Any func_args = args[2];
int64_t entry_index = args[3].cast<int64_t>();
ffi::Optional<ffi::Shape> shape_expr = std::nullopt;
if (args.size() == 5) {
shape_expr = args[4].cast<ffi::Shape>();
}
*rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index,
shape_expr);
})
.def_packed("vm.builtin.cuda_graph.get_cached_alloc", [](ffi::PackedArgs args, ffi::Any* rv) {
TVM_FFI_ICHECK_EQ(args.size(), 3);
VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
auto alloc_func = args[1].cast<ffi::ObjectRef>();
int64_t entry_index = args[2].cast<int64_t>();
*rv = extension->GetCachedAllocation(vm, alloc_func, entry_index);
});
}
} // namespace vm
} // namespace runtime
} // namespace tvm