/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/vm/paged_kv_cache.cc
* \brief Runtime paged KV cache object for language models.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/tensor.h>
#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>
#include "attn_backend.h"
#include "attn_utils.h"
#include "kv_state.h"
namespace tvm {
namespace runtime {
namespace vm {
//-------------------------------------------
// We keep the implementation private as
// it may be subject to future changes.
//
// Users can interact with it through the
// runtime API function calls
//-------------------------------------------
/*!
* \brief The paged KV cache for attention.
* - It supports managing the K/V data of **multiple sequences**.
* - It manages K/V values by doing paging along the sequence-length
* dimension with a configured page size.
* - To add a sequence to the cache, use AddSequence with a provided
* unique integer sequence id.
* - The basic usage of the paged KV cache, after initialization,
* in each round of model forwarding is the following:
* - step 1. use `BeginForward` to specify the list of sequence ids
* together with their append lengths,
* - step 2. use `Attention` to pass in the q/k/v values regarding
* the sequences and lengths specified in `BeginForward`. The
* attention is computed between input queries and the history
* key/values plus the input key/values. The input key/values
* will be added into the KV cache as well.
* - step 3. use `EndForward` to mark the end of forwarding this round.
* After calling `EndForward`, it is required to call `BeginForward`
* before calling any `Attention`.
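*
* A minimal usage sketch (for illustration only; `kv_cache`, `fused_qkv`, `attn_out`,
* `num_layers` and `sm_scale` are hypothetical names, and in practice the cache is
* driven through the registered runtime API functions rather than direct C++ calls):
*
* \code
* kv_cache->AddSequence(0);
* kv_cache->BeginForward(ffi::Shape({0}), ffi::Shape({16}), std::nullopt);
* for (int64_t layer = 0; layer < num_layers; ++layer) {
*   kv_cache->AttentionWithFusedQKV(layer, fused_qkv[layer], std::nullopt,
*                                   attn_out[layer], sm_scale);
* }
* kv_cache->EndForward();
* \endcode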
*/
class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
private:
/********************* Configuration *********************/
/*! \brief The page size (the sequence length each page manages) of the cache. */
const int64_t page_size_;
/*! \brief The number of layers in the model. */
const int64_t num_layers_;
/*! \brief The beginning layer id offset. */
const int64_t layer_id_begin_offset_;
/*! \brief The ending layer id offset. */
const int64_t layer_id_end_offset_;
/*! \brief The number of query/output heads in the model. */
const int64_t num_qo_heads_;
/*! \brief The number of key/value heads in the model. */
const int64_t num_kv_heads_;
/*! \brief The number of features each head has. */
const int64_t qk_head_dim_;
/*!
* \brief The number of features each head has for V.
* For layers that use multi-head attention, this field is overridden by qk_head_dim.
*/
const int64_t v_head_dim_;
/*! \brief The number of total pages allocated in KV cache. */
const int64_t num_total_pages_;
/*! \brief The maximum total sequence length in a prefill. */
const int64_t prefill_chunk_size_;
/*! \brief A boolean flag indicating if the KV cache supports sliding window. */
const bool support_sliding_window_;
/*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */
const bool support_layer_sliding_window_;
/*! \brief The attention kinds for each layer. */
const std::vector<AttnKind> attn_kinds_;
/*! \brief The RoPE application mode of KV cache.*/
const RoPEMode rope_mode_;
/*! \brief The RoPE scale. */
const double rotary_scale_;
/*! \brief The RoPE theta. */
const double rotary_theta_;
/*! \brief The optional RoPE extension factors for RoPE scaling. */
const ffi::Optional<Tensor> rope_ext_factors_;
/*! \brief The KV cache dtype. */
const DataType kv_dtype_;
/*! \brief We fix int32 to be the index dtype of auxiliary data. */
const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
/********************* Page Structures *********************/
/*!
* \brief The KV data managed by the KV cache.
* If the KV transfer function is specified, pages_ will be allocated by NVSHMEM as a whole
* Tensor, and pages_ will contain a tensor view of each layer.
* Otherwise, pages_ has `num_layers` Tensors, each of which
* has layout (num_pages, 2, num_heads, page_size, qk_head_dim).
* Along the "2" dimension, index 0 stands for K and 1 stands for V.
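* For example, under this layout the K value at page p, KV head h, in-page position s and
* feature f of a layer is pages_[layer][p, 0, h, s, f], and the matching V value is
* pages_[layer][p, 1, h, s, f].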
*/
std::vector<Tensor> pages_;
/*! \brief The whole KV cache allocated by NVSHMEM. */
Tensor nvshmem_pages_;
/*! \brief The list of ids of released pages for page reuse. */
std::vector<int32_t> free_page_ids_;
/*! \brief The mapping from sequence ids to sequences. */
std::unordered_map<int64_t, Sequence> seq_map_;
/********************* Sequence Block Structures *********************/
/*! \brief The list of all blocks once allocated. */
std::vector<Block> global_block_pool_;
/*! \brief The list of free available blocks (in their indices). */
std::vector<int32_t> free_block_idx_;
/*********** Current Batch Info & Auxiliary Arrays on Device ***********/
//-------------------------------------------
// The following fields are auxiliary arrays on device.
// All of them are directly derivable from the fields above.
// We store them for efficient execution of attentions,
// cache append, etc.
//-------------------------------------------
/*!
* \brief A boolean flag indicating if the auxiliary arrays are dirty.
* If it is dirty, an explicit "ComputeStreamWaitForCopyStream" should be invoked.
*/
bool dirty_aux_data_device_ = false;
/*! \brief The batch size of the current round of forwarding. */
int64_t cur_batch_size_;
/*! \brief The ids of the sequences in the current round of forwarding. */
ffi::Shape cur_seq_ids_;
/*! \brief The append lengths of the sequences in the current round of forwarding. */
ffi::Shape cur_append_lengths_;
/*! \brief Whether the current batch of sequences are token chains (not token trees). */
std::vector<bool> is_chain_on_depths_;
/*! \brief The number of fork depths in the current round of forwarding. */
int num_depths_;
/*! \brief Whether to compute attention after appending KV into cache or not. */
bool append_before_attn_;
/*! \brief Whether to use decode kernel for each depth. (see GetChunkedBlockIds) */
std::vector<bool> use_decode_kernel_;
/*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
bool is_decode_request_;
/*! \brief The KV transfer receiver disco group's PE offset in this forward.
If no KV is transferred, the receiver is -1.
Assume that all the KV values are transferred to the same receiver in this forward.
TODO: support multiple receivers. */
bool transfer_kv_;
bool page_to_page_transfer_kv_;
/*! \brief The auxiliary data manager for attention. */
std::unique_ptr<PagedKVCacheAuxDataManager> aux_data_manager_;
// Temporary arrays to store intermediate attention results.
Tensor temp_attn_q_device_;
Tensor temp_attn_k_device_;
Tensor temp_attn_v_device_;
Tensor temp_attn_output_device_;
Tensor temp_attn_lse_device_;
Tensor merged_attn_lse_device_;
std::vector<Tensor> temp_int_attn_workspace_;
std::vector<Tensor> temp_int_pinned_attn_workspace_;
Tensor temp_float_attn_workspace_;
//-------------------------------------------
// Below are the auxiliary data structure on CPU.
// We make them class members to avoid repetitive allocation time in BeginForward.
//-------------------------------------------
std::vector<HostMemoryVector> qo_indptr_on_depths_host_;
std::vector<HostMemoryVector> page_indptr_on_depths_host_;
std::vector<HostMemoryVector> page_indices_on_depths_host_;
std::vector<HostMemoryVector> page_indptr_sliding_window_on_depths_host_;
std::vector<HostMemoryVector> page_indices_sliding_window_on_depths_host_;
std::vector<HostMemoryVector> last_page_len_on_depths_host_;
std::vector<HostMemoryVector> sliding_window_offset_on_depths_host_;
std::vector<HostMemoryVector> sink_size_on_depths_host_;
std::vector<HostMemoryVector> k_rope_pos_offset_on_depths_host_;
std::vector<HostMemoryVector> k_rope_pos_offset_sliding_window_on_depths_host_;
HostMemoryVector k_ragged_rope_pos_offset_host_;
HostMemoryVector q_rope_position_map_host_;
HostMemoryVector append_position_map_host_;
HostMemoryVector cur_append_lengths_indptr_host_;
std::vector<HostMemoryVector> tree_attn_mask_host_;
std::vector<HostMemoryVector> tree_attn_mn_indptr_host_;
HostMemoryVector commit_copy_length_indptr_host_;
HostMemoryVector commit_copy_src_pos_in_page_table_host_;
HostMemoryVector commit_copy_dst_pos_in_page_table_host_;
HostMemoryVector kv_transfer_remote_position_map_host_;
HostMemoryVector kv_transfer_recver_id_host_;
HostMemoryVector kv_transfer_page_to_page_local_position_map_host_;
HostMemoryVector kv_transfer_page_to_page_remote_position_map_host_;
HostMemoryVector kv_transfer_page_to_page_recver_id_host_;
//-------------------------------------------
// For efficient memory management, the arrays above are
// over-allocated relative to their actual sizes.
// We create a view for the actual shapes of each of the arrays
// after each synchronization and pass these views as input for
// attention/append.
//-------------------------------------------
Tensor cur_append_length_indptr_view_;
Tensor k_ragged_rope_pos_offset_view_;
Tensor q_rope_position_map_view_;
Tensor append_position_map_view_;
Tensor kv_transfer_remote_position_map_view_;
Tensor kv_transfer_recver_id_view_;
Tensor kv_transfer_page_to_page_local_position_map_view_;
Tensor kv_transfer_page_to_page_remote_position_map_view_;
Tensor kv_transfer_page_to_page_recver_id_view_;
Tensor temp_attn_output_view_;
Tensor temp_attn_lse_view_;
Tensor merged_attn_lse_view_;
std::vector<Tensor> qo_indptr_on_depths_view_;
std::vector<Tensor> page_indptr_on_depths_view_;
std::vector<Tensor> page_indices_on_depths_view_;
std::vector<Tensor> page_indptr_sliding_window_on_depths_view_;
std::vector<Tensor> page_indices_sliding_window_on_depths_view_;
std::vector<Tensor> length_info_on_depths_view_;
std::vector<Tensor> layer_sliding_window_length_info_on_depths_view_;
std::vector<Tensor> k_rope_pos_offset_view_;
std::vector<Tensor> k_rope_pos_offset_sliding_window_view_;
std::vector<Tensor> tree_attn_mask_view_;
std::vector<Tensor> tree_attn_mn_indptr_view_;
ffi::Optional<ffi::Function> f_transpose_append_mha_;
ffi::Optional<ffi::Function> f_transpose_append_mla_;
ffi::Optional<ffi::Function> f_transfer_kv_;
ffi::Optional<ffi::Function> f_transfer_kv_page_to_page_ = std::nullopt;
ffi::Function f_compact_copy_;
std::unique_ptr<RaggedPrefillFunc> f_attention_prefill_ragged_;
std::unique_ptr<PagedPrefillFunc> f_attention_prefill_;
std::unique_ptr<PagedDecodeFunc> f_attention_decode_;
std::unique_ptr<PagedPrefillFunc> f_attention_prefill_sliding_window_;
std::unique_ptr<PagedDecodeFunc> f_attention_decode_sliding_window_;
std::unique_ptr<PagedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask_paged_kv_;
std::unique_ptr<RaggedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask_;
std::unique_ptr<PagedPrefillFunc> f_mla_prefill_;
ffi::Array<ffi::Function> f_merge_inplace_;
ffi::Function f_split_rotary_;
ffi::Function f_copy_single_page_;
ffi::Optional<ffi::Function> f_debug_get_kv_;
/*! \brief The device this PagedKVCache runs on. */
Device device_;
/*! \brief The device stream for the default computation operations. */
TVMStreamHandle compute_stream_ = nullptr;
/*! \brief The device stream for copying auxiliary data structure to GPU. */
TVMStreamHandle copy_stream_ = nullptr;
/*! \brief The device stream for KV transfer */
TVMStreamHandle kv_transfer_stream_ = nullptr;
public:
/*! \brief Constructor. Take the cache configuration and initialize the Tensors. */
explicit PagedAttentionKVCacheObj(
int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,
int64_t layer_id_end_offset, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim,
int64_t v_head_dim, std::vector<AttnKind> attn_kinds, int64_t reserved_num_seqs,
int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
RoPEMode rope_mode, double rotary_scale, double rotary_theta,
ffi::Optional<Tensor> rope_ext_factors, bool enable_kv_transfer, DLDataType dtype,
Device device, ffi::Optional<ffi::Function> f_transpose_append_mha,
ffi::Optional<ffi::Function> f_transpose_append_mla, ffi::Function f_compact_copy,
std::unique_ptr<RaggedPrefillFunc> f_attention_prefill_ragged,
std::unique_ptr<PagedPrefillFunc> f_attention_prefill,
std::unique_ptr<PagedDecodeFunc> f_attention_decode,
std::unique_ptr<PagedPrefillFunc> f_attention_prefill_sliding_window,
std::unique_ptr<PagedDecodeFunc> f_attention_decode_sliding_window,
std::unique_ptr<PagedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask_paged_kv,
std::unique_ptr<RaggedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask,
std::unique_ptr<PagedPrefillFunc> f_mla_prefill, ffi::Array<ffi::Function> f_merge_inplace,
ffi::Function f_split_rotary, ffi::Function f_copy_single_page, ffi::Function f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
layer_id_begin_offset_(layer_id_begin_offset),
layer_id_end_offset_(layer_id_end_offset),
num_qo_heads_(num_qo_heads),
num_kv_heads_(num_kv_heads),
qk_head_dim_(qk_head_dim),
v_head_dim_(v_head_dim),
num_total_pages_(num_total_pages),
prefill_chunk_size_(prefill_chunk_size),
support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(),
AttnKind::kMHASliding) != attn_kinds.end()
? false
: support_sliding_window),
support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(),
AttnKind::kMHASliding) != attn_kinds.end()),
attn_kinds_(std::move(attn_kinds)),
rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline
: rope_mode),
rotary_scale_(rotary_scale),
rotary_theta_(rotary_theta),
rope_ext_factors_(std::move(rope_ext_factors)),
kv_dtype_(DataType(dtype)),
f_transpose_append_mha_(std::move(f_transpose_append_mha)),
f_transpose_append_mla_(std::move(f_transpose_append_mla)),
f_compact_copy_(std::move(f_compact_copy)),
f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)),
f_attention_prefill_(std::move(f_attention_prefill)),
f_attention_decode_(std::move(f_attention_decode)),
f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)),
f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)),
f_attention_prefill_with_tree_mask_paged_kv_(
std::move(f_attention_prefill_with_tree_mask_paged_kv)),
f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)),
f_mla_prefill_(std::move(f_mla_prefill)),
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
f_copy_single_page_(std::move(f_copy_single_page)),
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
// Note: For MLA, sliding window and disaggregation are disabled for now.
if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMLA) != attn_kinds_.end()) {
TVM_FFI_ICHECK(!support_sliding_window_) << "Sliding window not supported yet for MLA";
TVM_FFI_ICHECK(!enable_kv_transfer) << "KV transfer not supported yet for MLA";
}
pages_.reserve(num_layers);
if (enable_kv_transfer) {
// For now, KV transfer only supports MHA.
for (AttnKind attn_kind : attn_kinds_) {
TVM_FFI_ICHECK(attn_kind == AttnKind::kMHA);
}
const auto f_nvshmem_init =
tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.init_nvshmem");
TVM_FFI_ICHECK(f_nvshmem_init.has_value())
<< "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when compiling TVM.";
const auto f_nvshmem_empty = tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.empty");
TVM_FFI_ICHECK(f_nvshmem_empty.has_value());
nvshmem_pages_ =
(*f_nvshmem_empty)(
ffi::Shape({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}),
dtype, device)
.cast<Tensor>();
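// Carve per-layer views out of the single NVSHMEM tensor: the byte offset of layer i
// equals i times the per-layer element count times the element byte size.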
for (int i = 0; i < num_layers; ++i) {
pages_.push_back(nvshmem_pages_.CreateView(
{num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype,
i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * qk_head_dim_ *
nvshmem_pages_.DataType().bytes()));
}
const auto f_transfer_kv_ptr = tvm::ffi::Function::GetGlobal("nvshmem.KVTransfer");
const auto f_transfer_kv_page_to_page_ptr =
tvm::ffi::Function::GetGlobal("nvshmem.KVTransferPageToPage");
TVM_FFI_ICHECK(f_transfer_kv_ptr.has_value());
TVM_FFI_ICHECK(f_transfer_kv_page_to_page_ptr.has_value());
f_transfer_kv_ = *f_transfer_kv_ptr;
f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
} else {
for (int i = 0; i < num_layers; ++i) {
ffi::Shape kv_cache_shape =
GetKVCacheShape(attn_kinds_[layer_id_begin_offset_ + i], num_total_pages,
reserved_num_seqs, num_kv_heads, page_size, qk_head_dim, v_head_dim);
pages_.push_back(Tensor::Empty(kv_cache_shape, dtype, device));
}
}
// Allocate the host memory.
Device preferred_host_device = GetPreferredHostDevice(device);
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
qo_indptr_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
page_indptr_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
page_indices_on_depths_host_.push_back(
HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device));
page_indptr_sliding_window_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
page_indices_sliding_window_on_depths_host_.push_back(
HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device));
last_page_len_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
sliding_window_offset_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
sink_size_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
k_rope_pos_offset_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
k_rope_pos_offset_sliding_window_on_depths_host_.push_back(
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs,
dtype_aux_, preferred_host_device));
tree_attn_mn_indptr_host_.push_back(
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
}
k_ragged_rope_pos_offset_host_ =
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
q_rope_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
append_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_remote_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_recver_id_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_page_to_page_local_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_page_to_page_remote_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
kv_transfer_page_to_page_recver_id_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
cur_append_lengths_indptr_host_ =
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
commit_copy_length_indptr_host_ =
HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
commit_copy_src_pos_in_page_table_host_ =
HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
dtype_aux_, preferred_host_device);
commit_copy_dst_pos_in_page_table_host_ =
HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
dtype_aux_, preferred_host_device);
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
if (NeedKernelBeginForward()) {
temp_int_attn_workspace_.push_back(
Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device));
temp_int_pinned_attn_workspace_.push_back(Tensor::Empty(
{kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device)));
}
qo_indptr_on_depths_view_.push_back(Tensor());
page_indptr_on_depths_view_.push_back(Tensor());
page_indices_on_depths_view_.push_back(Tensor());
page_indptr_sliding_window_on_depths_view_.push_back(Tensor());
page_indices_sliding_window_on_depths_view_.push_back(Tensor());
length_info_on_depths_view_.push_back(Tensor());
layer_sliding_window_length_info_on_depths_view_.push_back(Tensor());
k_rope_pos_offset_view_.push_back(Tensor());
k_rope_pos_offset_sliding_window_view_.push_back(Tensor());
tree_attn_mask_view_.push_back(Tensor());
tree_attn_mn_indptr_view_.push_back(Tensor());
is_chain_on_depths_.push_back(true);
}
// Additional workspace for the "prefill with ragged kv" kernel.
if (NeedKernelBeginForward()) {
temp_int_attn_workspace_.push_back(
Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device));
temp_int_pinned_attn_workspace_.push_back(Tensor::Empty(
{kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device)));
temp_float_attn_workspace_ =
Tensor::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device);
}
if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) {
temp_attn_q_device_ =
Tensor::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device);
temp_attn_k_device_ =
Tensor::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device);
temp_attn_v_device_ =
Tensor::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device);
}
temp_attn_output_device_ =
Tensor::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device);
temp_attn_lse_device_ =
Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device);
merged_attn_lse_device_ =
Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device);
for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
free_page_ids_.push_back(page_id);
}
// If the device is CUDA/ROCm, we create a standalone copy stream
// in order to hide the latency of the auxiliary data copy.
if (device.device_type == DLDeviceType::kDLCUDA ||
device.device_type == DLDeviceType::kDLROCM) {
// The compute stream is the default stream.
compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
kv_transfer_stream_ = DeviceAPI::Get(device)->CreateStream(device);
}
// Create the auxiliary data manager for attention.
// We only use the merged aux data for CUDA, since direct pointer
// operations may have issues on other platforms.
if (device_.device_type == DLDeviceType::kDLCUDA ||
device_.device_type == DLDeviceType::kDLCPU) {
aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
preferred_host_device, copy_stream_);
} else {
aux_data_manager_ = std::make_unique<PlainPagedKVCacheAuxDataManager>(
reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
preferred_host_device, copy_stream_);
}
// Right now only the "normal" RoPE mode supports the RoPE extension factors.
if (rope_ext_factors_.defined()) {
TVM_FFI_ICHECK(rope_mode_ == RoPEMode::kNormal)
<< "The RoPE mode must be normal to support RoPE extension factors.";
}
}
~PagedAttentionKVCacheObj() {
// Free the copy stream if defined.
if (copy_stream_ != nullptr) {
DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
}
if (kv_transfer_stream_ != nullptr) {
DeviceAPI::Get(device_)->FreeStream(device_, kv_transfer_stream_);
}
}
/*! \brief Reset the KV cache. */
void Clear() final {
seq_map_.clear();
free_page_ids_.clear();
for (int64_t page_id = num_total_pages_ - 1; page_id >= 0; --page_id) {
free_page_ids_.push_back(page_id);
}
global_block_pool_.clear();
free_block_idx_.clear();
dirty_aux_data_device_ = false;
}
/************** Sequence Management **************/
void AddSequence(int64_t seq_id) final {
TVM_FFI_ICHECK(seq_map_.find(seq_id) == seq_map_.end())
<< "The sequence \"" << seq_id << "\" is already in the KV cache.";
int32_t block_idx = GetFreeBlock();
seq_map_.insert({seq_id, Sequence(&global_block_pool_, block_idx)});
dirty_aux_data_device_ = true;
}
void RemoveSequence(int64_t seq_id) final {
auto it = seq_map_.find(seq_id);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
int32_t block_idx = it->second.last_block_idx;
// The block should have at least one reference, which comes from the sequence.
TVM_FFI_ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
// - Free pages in the last block.
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
}
free_block_idx_.push_back(block_idx);
block_idx = global_block_pool_[block_idx].parent_idx;
}
// - Decrease the external reference of the parent block.
if (block_idx != -1) {
TVM_FFI_ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1);
--global_block_pool_[block_idx].external_ref_cnt;
}
seq_map_.erase(it);
dirty_aux_data_device_ = true;
}
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
TVM_FFI_ICHECK(parent_it != seq_map_.end())
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
TVM_FFI_ICHECK(seq_map_.find(child_seq_id) == seq_map_.end())
<< "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
TVM_FFI_ICHECK_GE(fork_pos, -1)
<< "The forked position should be non-negative, or -1 for last position as default.";
TVM_FFI_ICHECK_LE(fork_pos, parent_it->second.seq_length)
<< "The forked position should not exceed the total length of parent sequence.";
TVM_FFI_ICHECK(parent_it->second.accepted_indices_committed)
<< "The parent sequence's token tree computed in the last round of forward has not been "
"committed with accepted nodes.";
if (fork_pos == -1) {
fork_pos = parent_it->second.seq_length;
}
if (parent_it->second.sliding_window_size != -1) {
// If the parent sequence has sliding window enabled, check that the forked position is
// within the sliding window sink size.
const Sequence& seq = parent_it->second;
int32_t sink_size = seq.seq_length - global_block_pool_[seq.last_block_idx].seq_length +
seq.last_block_attn_sink_size;
TVM_FFI_ICHECK_LE(fork_pos, sink_size)
<< "The parent sequence \"" << parent_seq_id
<< "\" is enabled with sliding window and thus only can be forked within sink size = "
<< sink_size << ". But the forked position = " << fork_pos << ".";
}
if (fork_pos == parent_it->second.seq_length && fork_pos % page_size_ == 0 &&
global_block_pool_[parent_it->second.last_block_idx].seq_length > 0) {
// To enable the parent sequence to continue decode after the fork,
// we add a new empty block at the end of the parent sequence.
// So the new decoded KV data will go into the new block.
int32_t new_block_idx = GetFreeBlock();
global_block_pool_[new_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[new_block_idx].parent_idx = parent_it->second.last_block_idx;
global_block_pool_[new_block_idx].external_ref_cnt = 1;
parent_it->second.last_block_idx = new_block_idx;
}
int32_t child_block_idx = GetFreeBlock();
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
int64_t in_block_offset = fork_pos;
for (int32_t forked_block_idx : trace) {
if (forked_block_idx != trace.back()) {
TVM_FFI_ICHECK_GT(global_block_pool_[forked_block_idx].seq_length, 0);
TVM_FFI_ICHECK_EQ(global_block_pool_[forked_block_idx].seq_length % page_size_, 0);
if (global_block_pool_[forked_block_idx].seq_length <= in_block_offset) {
in_block_offset -= global_block_pool_[forked_block_idx].seq_length;
continue;
}
}
int32_t in_page_offset = in_block_offset % page_size_;
int32_t moved_offset = in_block_offset - in_page_offset;
int32_t moved_pages = moved_offset / page_size_;
if (moved_pages == 0) {
// Forked at the first page in block
int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx;
if (parent_block_idx != -1) {
++global_block_pool_[parent_block_idx].external_ref_cnt;
}
// Update child block start position and parent index
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
} else {
// Forked at the second or a later page in block
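// For example (a sketch assuming page_size_ = 16): forking at in_block_offset = 40 gives
// in_page_offset = 8, moved_offset = 32 and moved_pages = 2, so the two full leading pages
// move to a new shared parent block, the forked block keeps the remainder, and the child
// partially copies the page containing the fork position below.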
int32_t parent_block_idx = GetFreeBlock();
// Insert new parent block before forked block and link child block
global_block_pool_[parent_block_idx].parent_idx =
global_block_pool_[forked_block_idx].parent_idx;
global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
global_block_pool_[parent_block_idx].external_ref_cnt = 2;
// Move common leading pages to new parent block
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
auto last_page = global_block_pool_[forked_block_idx].page_ids.begin() + moved_pages;
global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page);
// Update the start position of each block
global_block_pool_[parent_block_idx].start_pos =
global_block_pool_[forked_block_idx].start_pos;
global_block_pool_[forked_block_idx].start_pos += moved_offset;
// Update the in-block sequence length of each block
global_block_pool_[parent_block_idx].seq_length = moved_offset;
global_block_pool_[forked_block_idx].seq_length -= moved_offset;
// Update sliding window sink size if sliding window is enabled and the forked block is the
// last block
if (parent_it->second.sliding_window_size != -1 &&
forked_block_idx == parent_it->second.last_block_idx) {
TVM_FFI_ICHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size);
parent_it->second.last_block_attn_sink_size -= moved_offset;
}
}
global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
global_block_pool_[child_block_idx].seq_length = in_page_offset;
if (in_page_offset > 0) {
// Fork within a page and copy common page to child block partially
int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0];
int32_t tgt_page_id = GetFreePage();
global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
}
break;
}
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
dirty_aux_data_device_ = true;
}
void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) {
if (copy_stream_ != compute_stream_) {
// Set the copy stream for copy.
DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
}
for (int layer = 0; layer < num_layers_; ++layer) {
Tensor page_layer_view = pages_[layer];
f_copy_single_page_(page_layer_view, src_page_id, tgt_page_id, copy_length);
}
if (copy_stream_ != compute_stream_) {
// Set the compute stream back.
DeviceAPI::Get(device_)->SetStream(device_, compute_stream_);
}
}
void CompactKVCopy() {
int total_copy_length = commit_copy_length_indptr_host_.back();
TVM_FFI_ICHECK_GE(total_copy_length, 0);
if (total_copy_length == 0) {
return;
}
// Copy indptr/src/dst arrays to GPU.
aux_data_manager_->ResetCompactKVAuxDataCopy();
Tensor commit_copy_length_indptr_view =
aux_data_manager_->CopyCommitLengthIndptrAsync(&commit_copy_length_indptr_host_);
Tensor commit_copy_src_dst_pos_in_page_table_view =
aux_data_manager_->CopyCommitSrcDstPosInPageTableAsync(
&commit_copy_src_pos_in_page_table_host_, &commit_copy_dst_pos_in_page_table_host_);
aux_data_manager_->CommitCompactKVAuxDataCopy();
// Invoke the copy kernel on copy stream.
if (copy_stream_ != compute_stream_) {
// Set the copy stream for copy.
DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
}
TVM_FFI_ICHECK(f_compact_copy_.defined()) << "Function \"f_compact_copy\" is not defined.";
for (int layer = 0; layer < num_layers_; ++layer) {
f_compact_copy_(pages_[layer], commit_copy_length_indptr_view,
commit_copy_src_dst_pos_in_page_table_view, cur_batch_size_);
}
if (copy_stream_ != compute_stream_) {
// Set the compute stream back.
DeviceAPI::Get(device_)->SetStream(device_, compute_stream_);
}
// Note: We do not explicitly synchronize the copy stream here.
// The safety is guaranteed by the synchronization pushed by the next round
// of BeginForward, which also copies auxiliary data structure to GPU.
}
void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
int32_t attn_sink_size) final {
// Enabling sliding window for a sequence requires the cache to support either global or
// per-layer sliding window.
TVM_FFI_ICHECK(support_sliding_window_ || support_layer_sliding_window_)
<< "The KV cache does not support sliding window.";
auto it = seq_map_.find(seq_id);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
TVM_FFI_ICHECK_GE(attn_sink_size, 0)
<< "The specified attention sink size is expected to be non negative";
TVM_FFI_ICHECK_GT(sliding_window_size, 0)
<< "The specified sliding window size should be positive.";
TVM_FFI_ICHECK_LT(attn_sink_size, sliding_window_size)
<< "The attn sink size should be less than the sliding window size.";
// Set the sliding window flag of the sequence.
TVM_FFI_ICHECK_EQ(it->second.sliding_window_size, -1)
<< "A sequence cannot be enabled twice for sliding window.";
// Compute the total length of the prefix blocks of this sequence.
const Block& last_block = global_block_pool_[it->second.last_block_idx];
int32_t prefix_length = it->second.seq_length - last_block.seq_length;
TVM_FFI_ICHECK_GE(prefix_length, 0);
// Since the prefix blocks cannot slide, they are natural
// attention sinks here. When the prefix length is already
// larger than the specified attn sink size, we do not want to
// introduce more sink. Therefore, we update the given attn sink size.
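// For example, with attn_sink_size = 4, a prefix of length 1 leaves a last-block sink
// size of max(4 - 1, 0) = 3, while a prefix of length 10 leaves max(4 - 10, 0) = 0.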
it->second.last_block_attn_sink_size = std::max(attn_sink_size - prefix_length, 0);
it->second.sliding_window_size = sliding_window_size;
}
void PopN(int64_t seq_id, int32_t n) final {
auto it = seq_map_.find(seq_id);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
TVM_FFI_ICHECK_GE(n, 0) << "The pop length " << n << " cannot be negative.";
TVM_FFI_ICHECK_LE(n, it->second.seq_length)
<< "The sequence only has length " << it->second.seq_length
<< ", while the length of pop is " << n << " which exceeds the whole sequence length.";
if (n == 0) {
return;
}
int32_t block_idx = it->second.last_block_idx;
// The block should have at least one reference, which comes from the sequence.
TVM_FFI_ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
if (n > global_block_pool_[block_idx].seq_length) {
n -= global_block_pool_[block_idx].seq_length;
it->second.seq_length -= global_block_pool_[block_idx].seq_length;
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
}
free_block_idx_.push_back(block_idx);
block_idx = global_block_pool_[block_idx].parent_idx;
it->second.last_block_idx = block_idx;
continue;
}
if (n <= global_block_pool_[block_idx].seq_length) {
int64_t cur_npage = global_block_pool_[block_idx].page_ids.size();
int64_t tgt_npage =
(global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_;
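// For example, with page_size_ = 16, popping n = 5 from a block of length 33 (which
// spans 3 pages when sliding window is disabled) targets (33 - 5 + 15) / 16 = 2 pages,
// so the last page of the block is freed.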
while (cur_npage > tgt_npage) {
free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back());
global_block_pool_[block_idx].page_ids.pop_back();
--cur_npage;
}
it->second.seq_length -= n;
global_block_pool_[block_idx].seq_length -= n;
n = 0;
break;
}
}
if (n) {
// We use a temporary sequence id for fork.
// This temporary seq id will immediately end its effect outside this function.
int64_t temp_seq_id = -1 - seq_id;
TVM_FFI_ICHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
TVM_FFI_ICHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
RemoveSequence(seq_id);
TVM_FFI_ICHECK(seq_map_.find(seq_id) == seq_map_.end());
auto it = seq_map_.find(temp_seq_id);
seq_map_.insert({seq_id, it->second});
seq_map_.erase(temp_seq_id);
}
dirty_aux_data_device_ = true;
}
/************** Raw Info Query **************/
bool Empty() const final {
return seq_map_.empty() && //
free_block_idx_.size() == global_block_pool_.size() && //
free_page_ids_.size() == static_cast<size_t>(num_total_pages_);
}
int32_t GetNumAvailablePages() const final { return free_page_ids_.size(); }
int32_t GetTotalSequenceLength() const final {
int32_t total_seq_len = 0;
for (const auto& it : seq_map_) {
total_seq_len += it.second.seq_length;
}
return total_seq_len;
}
/************** Attention **************/
void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths,
const ffi::Optional<ffi::Shape>& opt_token_tree_parent_ptr) final {
// Note: MLA does not support tree attention for now.
if (attn_kinds_[0] == AttnKind::kMLA) {
TVM_FFI_ICHECK(!opt_token_tree_parent_ptr.defined())
<< "Tree attention is not supported yet for MLA";
}
TVM_FFI_ICHECK_EQ(seq_ids.size(), append_lengths.size())
<< "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
<< append_lengths.size() << ") mismatch.";
cur_batch_size_ = seq_ids.size();
cur_seq_ids_ = seq_ids;
cur_append_lengths_ = append_lengths;
// - Collect sequence/block/page information for attention.
std::vector<Sequence*> sequences;
std::vector<int32_t> last_block_length_before_append;
is_decode_request_ = true;
sequences.reserve(cur_batch_size_);
last_block_length_before_append.reserve(cur_batch_size_);
k_ragged_rope_pos_offset_host_.clear();
for (int i = 0; i < cur_batch_size_; ++i) {
auto it = seq_map_.find(seq_ids[i]);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
last_block_length_before_append.push_back(
global_block_pool_[it->second.last_block_idx].seq_length);
int k_rope_offset = it->second.seq_length;
if (!it->second.accepted_indices_committed) {
int tree_size = static_cast<int>(it->second.token_tree_parent_ptr.size());
k_rope_offset -= tree_size;
}
k_ragged_rope_pos_offset_host_.push_back(k_rope_offset);
it->second.seq_length += append_lengths[i];
if (append_lengths[i] != 1) {
is_decode_request_ = false;
}
}
auto [block_ids_on_depths, trailing_blocks] =
GetBlockIdsOnDepth(sequences, global_block_pool_, cur_batch_size_);
num_depths_ =
std::min(static_cast<int>(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth);
TVM_FFI_ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);
std::vector<std::vector<std::pair<int32_t, int32_t>>> chunked_block_ids_arr;
chunked_block_ids_arr.reserve(num_depths_);
use_decode_kernel_.clear();
for (int d = 0; d < num_depths_; ++d) {
// We force the blocks at maximum depth not to coalesce, so that they can be concatenated
// with the trailing blocks that exceed the maximum depth.
auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(
block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1,
cur_append_lengths_, global_block_pool_, is_decode_request_);
chunked_block_ids_arr.push_back(chunked_block_ids);
use_decode_kernel_.push_back(use_decode_kernel);
}
if (num_depths_ == kPagedKVCacheMaxBlockDepth) {
// Since we force the blocks at maximum depth not to coalesce, the number of output blocks
// at maximum depth must equal the current batch size.
TVM_FFI_ICHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
}
append_before_attn_ = !support_sliding_window_ && use_decode_kernel_.back();
if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
// When GQA group size is at least 4 and FlashInfer is enabled,
// we always use prefill kernel for better performance.
// Note: For MLA, we always use prefill kernel, so values in `use_decode_kernel` will
// be ignored for MLA.
std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
}
bool has_previous_tree =
std::any_of(sequences.begin(), sequences.end(),
[](const Sequence* sequence) { return !sequence->accepted_indices_committed; });
if (has_previous_tree) {
append_before_attn_ = true;
}
// - Check token tree validity and process the token tree.
if (opt_token_tree_parent_ptr.defined()) {
TVM_FFI_ICHECK(!support_sliding_window_) << "Tree attention does not support sliding window.";
TVM_FFI_ICHECK(rope_mode_ != RoPEMode::kInline)
<< "Tree attention does not support inline RoPE mode.";
ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value(), block_ids_on_depths,
trailing_blocks);
} else {
// The input batch does not form trees. So each sequence in the batch
// is required to have all past accepted tokens committed.
for (int i = 0; i < cur_batch_size_; ++i) {
Sequence* sequence = sequences[i];
TVM_FFI_ICHECK(sequence->accepted_indices_committed)
<< "The input batch does not form a tree, in which case the sequences in the input "
"batch are expected to have their accepted tokens token tree nodes committed. "
"Please invoke CommitAcceptedTokenTreeNodes for sequence "
<< seq_ids[i];
sequence->is_chain = true;
sequence->token_tree_parent_ptr.clear();
sequence->token_tree_node_depths.clear();
}
std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
}
if (append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is 1, we create the auxiliary
// data structure with regard to the page table after appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
}
for (int d = 0; d < num_depths_; ++d) {
HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d];
HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d];
HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d];
HostMemoryVector& page_indptr_sliding_window_h =
page_indptr_sliding_window_on_depths_host_[d];
HostMemoryVector& page_indices_sliding_window_h =
page_indices_sliding_window_on_depths_host_[d];
HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d];
HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d];
HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d];
HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
HostMemoryVector& k_rope_pos_offset_sliding_window_h =
k_rope_pos_offset_sliding_window_on_depths_host_[d];
qo_indptr_h.clear();
page_indptr_h.clear();
page_indices_h.clear();
page_indptr_sliding_window_h.clear();
page_indices_sliding_window_h.clear();
last_page_len_h.clear();
sliding_window_offset_h.clear();
sink_size_h.clear();
k_rope_pos_offset_h.clear();
k_rope_pos_offset_sliding_window_h.clear();
qo_indptr_h.push_back(0);
page_indptr_h.push_back(0);
page_indptr_sliding_window_h.push_back(0);
for (int i = 0; i < static_cast<int>(chunked_block_ids_arr[d].size()); ++i) {
const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i];
qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
if (block_id == -1) {
page_indptr_h.push_back(page_indptr_h.back());
page_indptr_sliding_window_h.push_back(page_indptr_sliding_window_h.back());
last_page_len_h.push_back(0);
sliding_window_offset_h.push_back(0);
sink_size_h.push_back(0);
k_rope_pos_offset_h.push_back(0);
k_rope_pos_offset_sliding_window_h.push_back(0);
} else {
if (d < kPagedKVCacheMaxBlockDepth - 1) {
// Blocks not at maximum depth
const Block& block = global_block_pool_[block_id];
page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
// Do the same for page_indices_sliding_window
}
// For sliding window, the first page and last page will both be partially used
page_indptr_sliding_window_h.push_back(
page_indptr_sliding_window_h.back() +
std::min(static_cast<int32_t>(block.page_ids.size()),
static_cast<int32_t>(1024 / page_size_ +
(block.seq_length % page_size_ ? 1 : 0))));
for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back();
i < static_cast<int32_t>(page_indices_h.size()); i++) {
page_indices_sliding_window_h.push_back(page_indices_h[i]);
}
// set up the page indices properly by choosing the last (sliding_window_size /
// page_size_) pages (at most)
last_page_len_h.push_back(
block.seq_length == 0
? 0
: (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
page_size_ +
1);
if (support_layer_sliding_window_) {
if (block.seq_length < 1024) {
sliding_window_offset_h.push_back(0);
} else {
sliding_window_offset_h.push_back(block.seq_length % page_size_);
}
} else {
sliding_window_offset_h.push_back(block.sliding_window_offset);
}
sink_size_h.push_back(block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
// If sliding window, we need to calculate the positional offset
if (support_layer_sliding_window_) {
k_rope_pos_offset_sliding_window_h.push_back(
std::max(0, block.start_pos + block.seq_length - 1024));
}
} else {
// Blocks at maximum depth
const Block& block = global_block_pool_[block_id];
int32_t num_pages = static_cast<int32_t>(block.page_ids.size());
int32_t total_seq_length = static_cast<int32_t>(block.seq_length);
int32_t last_block_id = block_id;
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
for (int32_t id : trailing_blocks[i]) {
// Collect trailing blocks if available
const Block& block = global_block_pool_[id];
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
num_pages += block.page_ids.size();
total_seq_length += block.seq_length;
last_block_id = id;
}
page_indptr_h.push_back(page_indptr_h.back() + num_pages);
page_indptr_sliding_window_h.push_back(
page_indptr_sliding_window_h.back() +
std::min(static_cast<int32_t>(block.page_ids.size()),
static_cast<int32_t>(1024 / page_size_ +
(block.seq_length % page_size_ ? 1 : 0))));
for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back();
i < static_cast<int32_t>(page_indices_h.size()); i++) {
page_indices_sliding_window_h.push_back(page_indices_h[i]);
}
const Block& last_block = global_block_pool_[last_block_id];
last_page_len_h.push_back(total_seq_length == 0
? 0
: (total_seq_length - last_block.sink_length +
last_block.sliding_window_offset - 1) %
page_size_ +
1);
if (support_layer_sliding_window_) {
if (last_block.seq_length < 1024) {
sliding_window_offset_h.push_back(0);
} else {
sliding_window_offset_h.push_back(last_block.seq_length % page_size_);
}
} else {
sliding_window_offset_h.push_back(last_block.sliding_window_offset);
}
sink_size_h.push_back(last_block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
if (support_layer_sliding_window_) {
k_rope_pos_offset_sliding_window_h.push_back(
std::max(0, block.start_pos + block.seq_length - 1024));
}
}
}
}
}
if (!append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is not 1, we create the auxiliary
// data structure with regard to the page table before appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
}
// Map each token position in the input batch to its position
// in the global KV cache. The mapping is used when appending k/v values.
q_rope_position_map_host_.clear();
append_position_map_host_.clear();
kv_transfer_remote_position_map_host_.clear();
kv_transfer_recver_id_host_.clear();
kv_transfer_page_to_page_local_position_map_host_.clear();
kv_transfer_page_to_page_remote_position_map_host_.clear();
kv_transfer_page_to_page_recver_id_host_.clear();
transfer_kv_ = false;
page_to_page_transfer_kv_ = false;
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = append_lengths[i];
const Block& block = global_block_pool_[sequences[i]->last_block_idx];
for (int64_t pos = 0; pos < append_length; ++pos) {
if (sequences[i]->token_tree_node_depths.empty()) {
q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos);
} else {
int64_t offset_in_tree =
static_cast<int64_t>(sequences[i]->token_tree_parent_ptr.size()) - append_length;
TVM_FFI_ICHECK_GE(offset_in_tree, 0);
q_rope_position_map_host_.push_back(
k_ragged_rope_pos_offset_host_[i] +
sequences[i]->token_tree_node_depths[offset_in_tree + pos]);
}
int32_t pos_in_block = block.seq_length - append_length + pos;
if (last_block_length_before_append[i] + pos < block.sink_length) {
// The location to write is part of the attention sink.
int32_t offset_in_block = last_block_length_before_append[i] + pos;
append_position_map_host_.push_back(block.page_ids[offset_in_block / page_size_] *
page_size_ +
offset_in_block % page_size_);
} else if (pos_in_block < block.sink_length) {
// The location to write is pinned by attn sink before the append.
// Therefore we cannot write into the location.
append_position_map_host_.push_back(-1);
} else {
// The location to write is in the sliding window.
int32_t offset_in_block = pos_in_block - block.sink_length + block.sliding_window_offset;
append_position_map_host_.push_back(block.page_ids[offset_in_block / page_size_] *
page_size_ +
offset_in_block % page_size_);
}
int64_t pos_in_seq = sequences[i]->seq_length - append_length + pos;
int64_t seq_send_start = sequences[i]->kv_transfer_metadata.start;
if (pos_in_seq < seq_send_start) {
kv_transfer_remote_position_map_host_.push_back(-1);
kv_transfer_recver_id_host_.push_back(-1);
} else {
transfer_kv_ = true;
kv_transfer_remote_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.remote_position_map[pos_in_seq - seq_send_start]);
kv_transfer_recver_id_host_.push_back(
sequences[i]->kv_transfer_metadata.recver_pe_offset);
}
}
if (!sequences[i]->kv_transfer_metadata.local_position_map.empty()) {
page_to_page_transfer_kv_ = true;
for (int pos = 0;
pos < static_cast<int>(sequences[i]->kv_transfer_metadata.local_position_map.size());
++pos) {
kv_transfer_page_to_page_local_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.local_position_map[pos]);
kv_transfer_page_to_page_remote_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.remote_position_map[pos]);
kv_transfer_page_to_page_recver_id_host_.push_back(
sequences[i]->kv_transfer_metadata.recver_pe_offset);
}
sequences[i]->kv_transfer_metadata.local_position_map.clear();
}
}
}
void EndForward() final {
if (kv_transfer_stream_ != nullptr) {
DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_, compute_stream_);
}
}
ffi::Shape DisaggPrepareRecv(int64_t seq_id, int append_length) final {
// No CPU to GPU copy is needed.
// Essentially we
// (step 1.) redirect the preparation to BeginForward.
BeginForward({seq_id}, {append_length}, /*opt_token_tree_parent_ptr=*/std::nullopt);
// (step 2.) fetch the append_position_map, compress and return.
// Compression format: [n, begin_1, length_1, begin_2, length_2, ..., begin_n, length_n]
// The compressed format will be decompressed to:
// [begin_1, begin_1+1, ..., begin_1+length_1-1, ..., begin_n, ..., begin_n+length_n-1]
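// For example, [2, 10, 3, 20, 2] decompresses to [10, 11, 12, 20, 21].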
TVM_FFI_ICHECK_EQ(append_position_map_host_.size(), append_length);
std::vector<int64_t> compressed_append_pos_map{/*num_segments=*/1,
append_position_map_host_[0]};
for (int i = 1; i < append_length; ++i) {
if (append_position_map_host_[i] != append_position_map_host_[i - 1] + 1) {
// Terminate the current segment.
compressed_append_pos_map.push_back(append_position_map_host_[i - 1] -
compressed_append_pos_map.back() + 1);
// Start a new segment.
++compressed_append_pos_map[0];
compressed_append_pos_map.push_back(append_position_map_host_[i]);
}
}
// Terminate the last segment.
compressed_append_pos_map.push_back(append_position_map_host_.back() -
compressed_append_pos_map.back() + 1);
// The compressed array size should be "num_segments * 2 + 1".
TVM_FFI_ICHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] * 2 + 1);
return ffi::Shape{compressed_append_pos_map};
}
void DisaggMarkSend(int64_t seq_id, int64_t begin,
const ffi::Shape& compressed_remote_position_map, int32_t recver_pe_offset) {
TVM_FFI_ICHECK(f_transfer_kv_.defined());
auto it = seq_map_.find(seq_id);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
Sequence* sequence = &it->second;
sequence->kv_transfer_metadata.start = begin;
int nsegments = compressed_remote_position_map[0];
sequence->kv_transfer_metadata.remote_position_map.clear();
for (int i = 0; i < nsegments; ++i) {
int begin = compressed_remote_position_map[2 * i + 1];
int length = compressed_remote_position_map[2 * i + 2];
for (int j = 0; j < length; ++j) {
sequence->kv_transfer_metadata.remote_position_map.push_back(begin + j);
}
}
sequence->kv_transfer_metadata.recver_pe_offset = recver_pe_offset;
sequence->kv_transfer_metadata.local_position_map.clear();
if (begin >= sequence->seq_length) {
return;
}
// Need to send existing KV.
TVM_FFI_ICHECK_GT(static_cast<int>(sequence->kv_transfer_metadata.remote_position_map.size()),
sequence->seq_length - begin)
<< "Need at least one token to prefill";
std::vector<int32_t> trace = sequence->GetBlockTrace(global_block_pool_);
sequence->kv_transfer_metadata.local_position_map.reserve(sequence->seq_length - begin);
bool done = false;
for (auto it_block_id = trace.rbegin(); it_block_id != trace.rend(); ++it_block_id) {
const Block& block = global_block_pool_[*it_block_id];
for (int i = block.seq_length - 1; i >= 0; --i) {
int32_t offset =
i < block.sink_length ? i : i - block.sink_length + block.sliding_window_offset;
int page_id = block.page_ids[offset / page_size_];
int page_offset = offset % page_size_;
sequence->kv_transfer_metadata.local_position_map.push_back(page_id * page_size_ +
page_offset);
if (static_cast<int>(sequence->kv_transfer_metadata.local_position_map.size()) ==
sequence->seq_length - begin) {
done = true;
break;
}
}
if (done) {
break;
}
}
std::reverse(sequence->kv_transfer_metadata.local_position_map.begin(),
sequence->kv_transfer_metadata.local_position_map.end());
}
void AttentionWithFusedQKV(int64_t layer_id, Tensor qkv_data, ffi::Optional<Tensor> mask,
Tensor o_data, double sm_scale) final {
// Part 1. Shape and dtype check.
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
TVM_FFI_ICHECK_GE(local_layer_id, 0);
TVM_FFI_ICHECK_LT(local_layer_id, num_layers_);
Tensor pages = pages_[local_layer_id];
TVM_FFI_ICHECK(qkv_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(o_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMHA ||
attn_kinds_[layer_id] == AttnKind::kMHASliding);
// qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim)
// o_data: (num_total_length, num_qo_heads, qk_head_dim)
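// For example, with num_qo_heads_ = 32 and num_kv_heads_ = 8, qkv_data's head dimension
// is 32 + 2 * 8 = 48; f_split_rotary_ below splits the fused tensor back into separate
// q/k/v buffers (the in-tensor head ordering is defined by that kernel).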
TVM_FFI_ICHECK_EQ(qkv_data->ndim, 3);
TVM_FFI_ICHECK_EQ(o_data->ndim, 3);
for (int dim = 0; dim < 3; ++dim) {
if (dim == 1) {
TVM_FFI_ICHECK_EQ(qkv_data->shape[1], num_qo_heads_ + 2 * num_kv_heads_);
TVM_FFI_ICHECK_EQ(o_data->shape[1], num_qo_heads_);
} else {
TVM_FFI_ICHECK_EQ(o_data->shape[dim], qkv_data->shape[dim]);
}
}
TVM_FFI_ICHECK_EQ(qkv_data->shape[2], qk_head_dim_);
int64_t total_seq_length = 0;
for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
total_seq_length += cur_append_lengths_[seq_id];
}
TVM_FFI_ICHECK_LE(total_seq_length, qkv_data->shape[0]);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);
Tensor q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_},
qkv_data->dtype);
Tensor k_data = temp_attn_k_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_},
qkv_data->dtype);
Tensor v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_},
qkv_data->dtype);
Tensor qkv_data_view = qkv_data;
Tensor o_data_view = o_data;
if (total_seq_length != qkv_data->shape[0]) {
qkv_data_view = qkv_data.CreateView(
{total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, qkv_data->dtype);
o_data_view =
o_data.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, qkv_data->dtype);
}
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
if (transfer_kv_) {
// The compute stream needs to wait for the KV transfer stream.
DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_, compute_stream_);
}
if (!rope_ext_factors_.defined()) {
f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
} else {
f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data,
rope_ext_factors_.value());
}
// Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
TVM_FFI_ICHECK(f_transpose_append_mha_.defined());
if (append_before_attn_) {
f_transpose_append_mha_.value()(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
}
// Part 4: KV transfer
if (page_to_page_transfer_kv_) {
DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, kv_transfer_stream_);
// FIXME: if the sender and recver's PP/TP degree do not match, we will need to first
// get the view of remote pages, and then take the specific remote layer.
// The KV transfer stream needs to wait for the compute stream.
f_transfer_kv_page_to_page_.value()(pages_[local_layer_id], pages_[local_layer_id],
kv_transfer_page_to_page_remote_position_map_view_,
kv_transfer_page_to_page_local_position_map_view_,
kv_transfer_page_to_page_recver_id_view_,
kv_transfer_stream_);
}
if (transfer_kv_) {
// FIXME: if the sender and recver's PP/TP degree do not match, we will need to first
// get the view of remote pages, and then take the specific remote layer.
// The KV transfer stream needs to wait for the compute stream.
DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_, kv_transfer_stream_);
f_transfer_kv_.value()(pages_[local_layer_id], k_data, v_data,
kv_transfer_remote_position_map_view_, kv_transfer_recver_id_view_,
kv_transfer_stream_);
}
// Part 5: perform attention
AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, sm_scale);
// Part 6. Append k/v data to kv-cache if flag "append_before_attn" is not set.
if (!append_before_attn_) {
f_transpose_append_mha_.value()(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
}
}
void SelfAttention(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
Tensor lse_data, double sm_scale) final {
// Shape and dtype check.
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
TVM_FFI_ICHECK_GE(local_layer_id, 0);
TVM_FFI_ICHECK_LT(local_layer_id, num_layers_);
Tensor pages = pages_[local_layer_id];
TVM_FFI_ICHECK(q_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(k_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(v_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(o_data.DataType() == pages.DataType());
AttnKind attn_kind = attn_kinds_[layer_id];
// q_data: (num_total_length, num_qo_heads, qk_head_dim)
// k_data: (num_total_length, num_kv_heads, qk_head_dim)
// v_data: (num_total_length, num_kv_heads, v_head_dim)
// o_data: (num_total_length, num_qo_heads, v_head_dim)
int64_t total_seq_length = 0;
for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
total_seq_length += cur_append_lengths_[seq_id];
}
TVM_FFI_ICHECK_EQ(q_data->ndim, 3);
TVM_FFI_ICHECK_EQ(k_data->ndim, 3);
TVM_FFI_ICHECK_EQ(v_data->ndim, 3);
TVM_FFI_ICHECK_EQ(o_data->ndim, 3);
TVM_FFI_ICHECK_EQ(q_data->shape[0], total_seq_length);
TVM_FFI_ICHECK_EQ(k_data->shape[0], total_seq_length);
TVM_FFI_ICHECK_EQ(v_data->shape[0], total_seq_length);
TVM_FFI_ICHECK_EQ(o_data->shape[0], total_seq_length);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);
if (attn_kind == AttnKind::kMHA) {
MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
} else {
MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
}
}
void CrossAttention(int64_t layer_id, Tensor q_data, Tensor o_data, Tensor lse_data,
double sm_scale) final {
// Shape and dtype check.
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
TVM_FFI_ICHECK_GE(local_layer_id, 0);
TVM_FFI_ICHECK_LT(local_layer_id, num_layers_);
Tensor pages = pages_[local_layer_id];
TVM_FFI_ICHECK(q_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(o_data.DataType() == pages.DataType());
AttnKind attn_kind = attn_kinds_[layer_id];
// q_data: (num_total_length, num_qo_heads, qk_head_dim)
// o_data: (num_total_length, num_qo_heads, v_head_dim)
int64_t total_seq_length = 0;
for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
total_seq_length += cur_append_lengths_[seq_id];
}
TVM_FFI_ICHECK_EQ(q_data->ndim, 3);
TVM_FFI_ICHECK_EQ(o_data->ndim, 3);
TVM_FFI_ICHECK_EQ(q_data->shape[0], total_seq_length);
TVM_FFI_ICHECK_EQ(o_data->shape[0], total_seq_length);
TVM_FFI_ICHECK_EQ(q_data->shape[1], num_qo_heads_);
TVM_FFI_ICHECK_EQ(o_data->shape[1], num_qo_heads_);
TVM_FFI_ICHECK_EQ(q_data->shape[2], qk_head_dim_);
TVM_FFI_ICHECK_EQ(o_data->shape[2], v_head_dim_);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);
if (attn_kind == AttnKind::kMHA) {
MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale,
/*is_first_kernel=*/true);
} else {
MLACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale);
}
}
void AppendMLAKV(int64_t layer_id, Tensor kv_data) final {
// Shape and dtype check.
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
TVM_FFI_ICHECK_GE(local_layer_id, 0);
TVM_FFI_ICHECK_LT(local_layer_id, num_layers_);
Tensor pages = pages_[local_layer_id];
TVM_FFI_ICHECK(kv_data.DataType() == pages.DataType());
TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMLA);
// kv_data: (num_total_length, qk_head_dim)
TVM_FFI_ICHECK_EQ(kv_data->ndim, 2);
int64_t total_seq_length = 0;
for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
total_seq_length += cur_append_lengths_[seq_id];
}
TVM_FFI_ICHECK_LE(kv_data->shape[0], total_seq_length);
TVM_FFI_ICHECK_EQ(kv_data->shape[1], qk_head_dim_);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);
TVM_FFI_ICHECK(f_transpose_append_mla_.defined());
f_transpose_append_mla_.value()(pages_[local_layer_id], kv_data, append_position_map_view_);
}
ffi::Array<Tensor> MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn,
Tensor o_cross_attn, Tensor lse_cross_attn) final {
TVM_FFI_ICHECK_GE(f_merge_inplace_.size(), 2)
<< "The general attention merge function is not defined.";
f_merge_inplace_[1](o_self_attn, lse_self_attn, o_cross_attn, lse_cross_attn);
return {o_self_attn, lse_self_attn};
}
void LinearAttention(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data,
double sm_scale) {
// Todo(ruihang): implement it
}
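/*!
* \brief Commit the accepted nodes of the token trees (as used in, e.g.,
* speculative decoding). For each sequence, the path from the given accepted
* leaf up to the tree root is kept: for non-chain trees its KV entries are
* compacted to the front of the appended region, and all other speculated
* tokens are popped from the sequence. A leaf index of -1 means no node is
* accepted for that sequence.
*/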
void CommitAcceptedTokenTreeNodes(const ffi::Shape& seq_ids,
const ffi::Shape& leaf_indices) final {
TVM_FFI_ICHECK_EQ(seq_ids.size(), leaf_indices.size())
<< "The given seq_ids and leaf_indices have different size.";
int num_seq_to_commit = seq_ids.size();
std::vector<Sequence*> sequences;
sequences.reserve(num_seq_to_commit);
bool is_chain = true;
for (int i = 0; i < num_seq_to_commit; ++i) {
auto it = seq_map_.find(seq_ids[i]);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
is_chain = it->second.is_chain;
TVM_FFI_ICHECK(leaf_indices[i] == -1 || !it->second.accepted_indices_committed)
<< "The accepted nodes of sequence " << seq_ids[i] << " are already committed.";
TVM_FFI_ICHECK_GE(leaf_indices[i], -1)
<< "Invalid tree index " << leaf_indices[i] << " which is less than -1";
TVM_FFI_ICHECK_LT(leaf_indices[i],
static_cast<int64_t>(it->second.token_tree_parent_ptr.size()))
<< "Invalid tree index " << leaf_indices[i]
<< " which is larger than or equals to the append length "
<< it->second.token_tree_parent_ptr.size() << " of the sequence";
}
if (!is_chain) {
commit_copy_length_indptr_host_.clear();
commit_copy_src_pos_in_page_table_host_.clear();
commit_copy_dst_pos_in_page_table_host_.clear();
commit_copy_length_indptr_host_.push_back(0);
for (int i = 0; i < num_seq_to_commit; ++i) {
if (leaf_indices[i] == -1) {
// No node is accepted. All nodes in the token tree need to be popped.
commit_copy_length_indptr_host_.push_back(commit_copy_length_indptr_host_.back());
continue;
}
// Get the accepted node path on the token tree.
std::vector<int32_t> path_on_tree;
path_on_tree.reserve(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
int node = leaf_indices[i];
while (node != -1) {
path_on_tree.push_back(node);
node = sequences[i]->token_tree_parent_ptr[node];
}
TVM_FFI_ICHECK_EQ(path_on_tree.size(),
sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
// Get the destination array (range [0, path_length)) of KV cache copy.
std::vector<int32_t> copy_dst_pos_in_seq;
copy_dst_pos_in_seq.resize(path_on_tree.size());
std::iota(copy_dst_pos_in_seq.rbegin(), copy_dst_pos_in_seq.rend(), /*value=*/0);
// Remove the positions whose KV data do not need copy.
while (!path_on_tree.empty() && path_on_tree.back() == copy_dst_pos_in_seq.back()) {
path_on_tree.pop_back();
copy_dst_pos_in_seq.pop_back();
}
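// Worked example (hypothetical tree): with 5 nodes whose parents are
// [-1, 0, 0, 2, 2] and accepted leaf 4, the accepted path (leaf to root) is
// [4, 2, 0] with destinations [2, 1, 0]. Node 0 already sits at position 0 and
// is dropped here, leaving src [2, 4] -> dst [1, 2] after the reversal below.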
// Reverse the position arrays so that they are in ascending order.
std::reverse(path_on_tree.begin(), path_on_tree.end());
std::reverse(copy_dst_pos_in_seq.begin(), copy_dst_pos_in_seq.end());
// Convert the in-sequence src/dst positions to src/dst positions in page table
// by looking up "append_position_map".
for (int p = 0; p < static_cast<int>(path_on_tree.size()); ++p) {
commit_copy_src_pos_in_page_table_host_.push_back(
append_position_map_host_[cur_append_lengths_indptr_host_[i] + path_on_tree[p]]);
commit_copy_dst_pos_in_page_table_host_.push_back(
append_position_map_host_[cur_append_lengths_indptr_host_[i] +
copy_dst_pos_in_seq[p]]);
}
commit_copy_length_indptr_host_.push_back(commit_copy_length_indptr_host_.back() +
path_on_tree.size());
}
// Compact the KV data for each sequence by copying KV data.
CompactKVCopy();
}
// - Update the KV cache page data structure.
// Note: Function "PopN" only changes the page table structure and does not
// change the KV cache data. Therefore, we can directly use it, since
// we have already launched all copies.
for (int i = 0; i < num_seq_to_commit; ++i) {
int64_t length_to_pop =
cur_append_lengths_[i] -
(leaf_indices[i] != -1 ? (sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1) : 0);
PopN(cur_seq_ids_[i], length_to_pop);
// Reset the sequence states.
sequences[i]->accepted_indices_committed = true;
sequences[i]->token_tree_parent_ptr.clear();
sequences[i]->token_tree_node_depths.clear();
}
}
Tensor GetQueryPositions() final {
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);
return q_rope_position_map_view_;
}
void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, Tensor k_data,
Tensor v_data) final {
TVM_FFI_ICHECK(f_debug_get_kv_.defined())
<< "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when "
"initialization. Please construct the KV cache with `f_debug_get_kv`.";
const Sequence& seq = seq_map_.at(seq_id);
TVM_FFI_ICHECK_GE(start_pos, 0)
<< "DebugGetKV does not accept negative start_pos " << start_pos;
TVM_FFI_ICHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos";
TVM_FFI_ICHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\"";
// k/v_data: (num_layers, seq_length, num_kv_heads, qk_head_dim)
static constexpr const char* error_msg =
"DebugGetKV expects the k_data in layout (num_layers, seq_length, num_kv_heads, "
"qk_head_dim).";
std::vector<Tensor*> vec_kv_data = {&k_data, &v_data};
for (const Tensor* data_ptr : vec_kv_data) {
TVM_FFI_ICHECK_EQ((*data_ptr)->ndim, 4) << error_msg;
TVM_FFI_ICHECK_EQ((*data_ptr)->shape[0], num_layers_)
<< error_msg << " The number of layers mismatches.";
TVM_FFI_ICHECK_EQ((*data_ptr)->shape[1], end_pos - start_pos)
<< error_msg << " The sequence length mismatches.";
TVM_FFI_ICHECK_EQ((*data_ptr)->shape[2], num_kv_heads_)
<< error_msg << " The number of heads mismatches.";
TVM_FFI_ICHECK_EQ((*data_ptr)->shape[3], qk_head_dim_)
<< error_msg << " The number of head features mismatches.";
}
std::vector<int32_t> trace = seq.GetBlockTrace(global_block_pool_);
std::vector<int32_t> append_position_map;
append_position_map.reserve(seq.seq_length);
for (int32_t block_id : trace) {
const Block& block = global_block_pool_[block_id];
for (int i = 0; i < block.seq_length; ++i) {
int32_t offset =
i < block.sink_length ? i : i - block.sink_length + block.sliding_window_offset;
int page_id = block.page_ids[offset / page_size_];
int page_offset = offset % page_size_;
append_position_map.push_back(page_id * page_size_ + page_offset);
}
}
Tensor position_map_device = Tensor::Empty({end_pos - start_pos}, dtype_aux_, device_);
position_map_device.CopyFromBytes(
append_position_map.data() + start_pos,
(end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8));
for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMHA)
<< "Only MHA is supported for DebugGetKV";
f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data, v_data, layer_id);
}
}
void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, Tensor kv_data) final {
TVM_FFI_ICHECK(f_debug_get_kv_.defined())
<< "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when "
"initialization. Please construct the KV cache with `f_debug_get_kv`.";
const Sequence& seq = seq_map_.at(seq_id);
TVM_FFI_ICHECK_GE(start_pos, 0)
<< "DebugGetKV does not accept negative start_pos " << start_pos;
TVM_FFI_ICHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos";
TVM_FFI_ICHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\"";
// kv_data: (num_layers, seq_length, qk_head_dim)
static constexpr const char* error_msg =
"DebugGetKV expects the kv_data in layout (num_layers, seq_length, qk_head_dim).";
TVM_FFI_ICHECK_EQ(kv_data->ndim, 3) << error_msg;
TVM_FFI_ICHECK_EQ(kv_data->shape[0], num_layers_)
<< error_msg << " The number of layers mismatches.";
TVM_FFI_ICHECK_EQ(kv_data->shape[1], end_pos - start_pos)
<< error_msg << " The sequence length mismatches.";
TVM_FFI_ICHECK_EQ(kv_data->shape[2], qk_head_dim_)
<< error_msg << " The number of head features mismatches.";
std::vector<int32_t> trace = seq.GetBlockTrace(global_block_pool_);
std::vector<int32_t> append_position_map;
append_position_map.reserve(seq.seq_length);
for (int32_t block_id : trace) {
const Block& block = global_block_pool_[block_id];
for (int i = 0; i < block.seq_length; ++i) {
int32_t offset =
i < block.sink_length ? i : i - block.sink_length + block.sliding_window_offset;
int page_id = block.page_ids[offset / page_size_];
int page_offset = offset % page_size_;
append_position_map.push_back(page_id * page_size_ + page_offset);
}
}
Tensor position_map_device = Tensor::Empty({end_pos - start_pos}, dtype_aux_, device_);
position_map_device.CopyFromBytes(
append_position_map.data() + start_pos,
(end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8));
for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMLA)
<< "Only MHA is supported for DebugGetKVMLA";
f_debug_get_kv_.value()(pages_[layer_id], position_map_device, kv_data, layer_id);
}
}
void DebugSetKV(int64_t seq_id, int64_t start_pos, Tensor k_data, Tensor v_data) final {
TVM_FFI_ICHECK(false) << "DebugSetKV for PageAttentionKVCache not implemented yet.";
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.PagedAttentionKVCache", PagedAttentionKVCacheObj,
AttentionKVCacheObj);
private:
/*! \brief Get a new free page and return its id. */
int32_t GetFreePage() {
// Find a page from the free page pools.
TVM_FFI_ICHECK(!free_page_ids_.empty()) << "The KV cache is full. No page can be allocated.";
int32_t page_id = free_page_ids_.back();
free_page_ids_.pop_back();
return page_id;
}
/*! \brief Get a new free block and return its index. */
int32_t GetFreeBlock() {
if (!free_block_idx_.empty()) {
int32_t block_idx = free_block_idx_.back();
free_block_idx_.pop_back();
global_block_pool_[block_idx].Reset();
TVM_FFI_ICHECK_EQ(global_block_pool_[block_idx].index, block_idx);
return block_idx;
}
int32_t block_idx = global_block_pool_.size();
global_block_pool_.push_back(Block(block_idx));
return block_idx;
}
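/*!
* \brief Construct the tree-attention masks and the mn indptr arrays for the
* token trees of the current batch, per depth. The parent pointers of each
* sequence's token tree are recorded, the tree is checked for degenerating into
* a chain (which allows the simpler ragged/paged kernels), and for non-chain
* trees a mask encoding based on DFS intervals is emitted.
*/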
void ConstructTokenTreeMask(const std::vector<Sequence*>& sequences,
const ffi::Shape& token_tree_parent_ptr,
const std::vector<std::vector<int32_t>>& block_ids_on_depths,
const std::vector<std::vector<int32_t>>& trailing_blocks) {
// Check whether the token tree of a sequence should be handled at the current depth.
auto check_for_sequence = [&](int seq_i, int depth) -> bool {
if (!append_before_attn_) {
return true;
}
// Check if the last block of the sequence is on the current depth.
if (block_ids_on_depths[depth][seq_i] == sequences[seq_i]->last_block_idx ||
(depth + 1 == kPagedKVCacheMaxBlockDepth && !trailing_blocks[seq_i].empty())) {
return true;
}
return false;
};
for (int d = 0; d < num_depths_; ++d) {
// We check if the token tree degenerates into a chain,
// because chain cases can use a simplified attention workflow.
TVM_FFI_ICHECK_LT(d, tree_attn_mask_host_.size());
TVM_FFI_ICHECK_LT(d, tree_attn_mn_indptr_host_.size());
HostMemoryVector& tree_attn_mn_indptr = tree_attn_mn_indptr_host_[d];
HostMemoryVector& tree_attn_mask = tree_attn_mask_host_[d];
std::vector<bool> seq_in_current_depth(cur_batch_size_, false);
tree_attn_mn_indptr.clear();
tree_attn_mask.clear();
std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
bool is_chain = true;
// - Construct the mn indptr array, which is the indptr of the mask size of each sequence.
tree_attn_mn_indptr.push_back(0);
TVM_FFI_ICHECK_EQ(sequences.size(), cur_batch_size_);
TVM_FFI_ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_);
int64_t token_tree_parent_ptr_offset = 0;
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = cur_append_lengths_[i];
seq_in_current_depth[i] = check_for_sequence(i, d);
if (!seq_in_current_depth[i]) {
tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back());
token_tree_parent_ptr_offset += append_length; // Skip the token tree of this sequence.
continue;
}
// Update the token tree parent pointers.
TVM_FFI_ICHECK_LE(sequences[i]->token_tree_parent_ptr.size(),
global_block_pool_[sequences[i]->last_block_idx].seq_length)
<< "The token tree size is larger than the sequence length of the last block.";
std::copy(token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset,
token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset + append_length,
std::back_inserter(sequences[i]->token_tree_parent_ptr));
token_tree_parent_ptr_offset += append_length;
TVM_FFI_ICHECK_LE(sequences[i]->token_tree_parent_ptr.size(), kTreeAttnMaxTreeSize)
<< "The tree size is " << append_length << " which exceeds the maximum tree size limit "
<< kTreeAttnMaxTreeSize;
tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back() +
sequences[i]->token_tree_parent_ptr.size());
}
TVM_FFI_ICHECK_EQ(token_tree_parent_ptr.size(), token_tree_parent_ptr_offset)
<< "Invalid token tree size. The sum of \"append_lengths\" is "
<< token_tree_parent_ptr_offset << " while there are " << token_tree_parent_ptr.size()
<< " elements in \"token_tree_parent_ptr\".";
// - Construct the mask of each sequence.
for (int i = 0; i < cur_batch_size_; ++i) {
if (!seq_in_current_depth[i]) {
continue;
}
int64_t tree_size = sequences[i]->token_tree_parent_ptr.size();
std::vector<std::vector<int32_t>> mask;
std::vector<int32_t> depth;
mask.reserve(tree_size);
depth.reserve(tree_size);
sequences[i]->is_chain = true;
sequences[i]->accepted_indices_committed = false;
std::unordered_map<int, std::vector<int>> tree_parent_to_children;
std::vector<int> tree_roots;
for (int n = 0; n < tree_size; ++n) {
TVM_FFI_ICHECK_LT(sequences[i]->token_tree_parent_ptr[n], n)
<< "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
<< sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n;
TVM_FFI_ICHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1)
<< "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
<< sequences[i]->token_tree_parent_ptr[n];
if (sequences[i]->token_tree_parent_ptr[n] != n - 1) {
// The parent of the current node is not the last node.
// Therefore the tree is not a chain.
sequences[i]->is_chain = false;
is_chain = false;
}
tree_parent_to_children[sequences[i]->token_tree_parent_ptr[n]].push_back(n);
if (sequences[i]->token_tree_parent_ptr[n] != -1) {
depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1);
} else {
depth.push_back(0);
tree_roots.push_back(n);
}
}
std::vector<std::pair<int, int>> tree_order(tree_size);
int order = 0;
std::function<int(int)> tree_dfs = [&order, &tree_order, &tree_parent_to_children,
&tree_dfs](int node) -> int {
tree_order[node].first = order++;
int upper_bound = tree_order[node].first + 1;
for (int child : tree_parent_to_children[node]) {
upper_bound = std::max(upper_bound, tree_dfs(child));
}
tree_order[node].second = upper_bound;
return upper_bound;
};
for (auto root : tree_roots) {
tree_dfs(root);
}
for (int n = 0; n < tree_size; ++n) {
tree_attn_mask.push_back(tree_order[n].first);
tree_attn_mask.push_back(tree_order[n].second);
}
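// The mask thus stores, for every node, its [entry, exit) interval in DFS
// order. A node m is an ancestor of (or equal to) node n exactly when
// order[m].first <= order[n].first < order[m].second, which is presumably how
// the tree-mask kernels test whether a query token may attend to a KV token.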
sequences[i]->token_tree_node_depths = std::move(depth);
}
is_chain_on_depths_[d] = is_chain;
if (!append_before_attn_) {
break;
}
}
}
/*!
* \brief Slide the KV cache window of the given sequence when
* it has sliding window enabled.
* \param seq The sequence whose KV cache window is to be slid.
*/
void SlideWindowForSequence(Sequence* seq) {
// - No action when the sequence is not enabled for sliding window.
if (seq->sliding_window_size == -1 || !support_sliding_window_) {
return;
}
// - No action when the sequence length does not exceed the window size.
if (seq->seq_length <= seq->sliding_window_size) {
return;
}
int32_t length_to_slide = seq->seq_length - seq->sliding_window_size;
// - Get the last block of the sequence.
Block& block = global_block_pool_[seq->last_block_idx];
// - If the attention sink exists and the last block has no previous
// sink length, it means this is the first time we slide the sequence,
// and thus we set the sink length of the last block, the index of the
// first sliding page, and the starting offset in the first sliding page.
if (seq->last_block_attn_sink_size > 0 && block.sink_length == 0) {
TVM_FFI_ICHECK_EQ(block.sliding_window_offset, 0);
block.sink_length = seq->last_block_attn_sink_size;
block.sliding_window_offset = seq->last_block_attn_sink_size;
}
// - The sink pages cannot be slid out.
int32_t num_sink_pages = (block.sink_length + page_size_ - 1) / page_size_;
// - Compute the first sliding page index and in-page sliding window
// start offset in the first sliding page after sliding.
int32_t page_idx_after_sliding = (block.sliding_window_offset + length_to_slide) / page_size_;
int32_t page_start_offset_after_sliding =
(block.sliding_window_offset + length_to_slide) % page_size_;
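// Worked example (hypothetical numbers): with page_size_ = 16, sink_length = 4,
// sliding_window_offset = 10 and length_to_slide = 30, we get num_sink_pages = 1,
// page_idx_after_sliding = 40 / 16 = 2 and page_start_offset_after_sliding = 8;
// the loop below then frees one full page, so the first sliding page becomes
// the page right after the sink page.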
// - Free the pages that have fully slid out of the window.
while (page_idx_after_sliding > num_sink_pages) {
if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) {
free_page_ids_.push_back(block.page_ids[num_sink_pages]);
}
block.page_ids.erase(block.page_ids.begin() + num_sink_pages);
--page_idx_after_sliding;
}
// - The first sliding page after sliding is either the last sink page,
// or the page next to the last sink page.
TVM_FFI_ICHECK(page_idx_after_sliding == num_sink_pages - 1 ||
page_idx_after_sliding == num_sink_pages);
// - Update the length of the sequence and the block.
seq->seq_length = seq->sliding_window_size;
block.seq_length -= length_to_slide;
block.sliding_window_offset =
page_idx_after_sliding * page_size_ + page_start_offset_after_sliding;
TVM_FFI_ICHECK_GE(block.seq_length, block.sink_length);
TVM_FFI_ICHECK_GE(block.sliding_window_offset, block.sink_length);
TVM_FFI_ICHECK_EQ(
(block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) /
page_size_,
block.page_ids.size());
}
/*!
* \brief Reserve extra append length in the last block of the given
* sequence, as preparation of the incoming KV cache append.
* New pages will be allocated to the block until the total
* capacity can cover the current sequence length (before reservation)
* plus the required append length.
* \param seq The sequence whose last block is to be processed.
* \param append_length The extra append length to reserve for the block.
* \note We apply sliding window in this function.
*/
void ReserveAppendLengthInSeq(Sequence* seq, int64_t append_length) {
int32_t block_idx = seq->last_block_idx;
Block& block = global_block_pool_[block_idx];
TVM_FFI_ICHECK_GT(append_length, 0) << "Append with length 0 is not allowed.";
TVM_FFI_ICHECK_EQ(block.external_ref_cnt, 1)
<< "The block is " << block.external_ref_cnt - 1
<< "-time referenced by other blocks, thus cannot accept new KV values.";
// ==================== Reserve ====================
// The reservation is based on the current sequence length.
// If "current sequence + append length" does not exceed the
// current capacity (number of pages * page size), no action is taken.
int64_t cur_npage = block.page_ids.size();
int64_t tgt_npage = (block.seq_length - block.sink_length + block.sliding_window_offset +
append_length + page_size_ - 1) /
page_size_;
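// For example (hypothetical numbers): with page_size_ = 16, seq_length = 20,
// sink_length = 4, sliding_window_offset = 10 and append_length = 8, we need
// tgt_npage = (20 - 4 + 10 + 8 + 15) / 16 = 3 pages in total.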
for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) {
// When sliding window is enabled for the seq, we can "borrow temporary pages (-1)",
// since the pages that need to be slid out might not have been released yet.
if (free_page_ids_.empty() && seq->sliding_window_size != -1 && support_sliding_window_) {
block.page_ids.push_back(kPagedKVCacheTempPageId);
} else {
block.page_ids.push_back(GetFreePage());
}
}
block.seq_length += append_length;
// ==================== Slide ====================
// Slide the sequence so that the pages exceeding the sliding window are released.
SlideWindowForSequence(seq);
if (support_sliding_window_) {
for (int i = 0; i < static_cast<int>(block.page_ids.size()); ++i) {
if (block.page_ids[i] == kPagedKVCacheTempPageId) {
// Re-allocate the temporary pages after sliding window release.
block.page_ids[i] = GetFreePage();
}
}
}
dirty_aux_data_device_ = true;
}
/*! \brief Check whether BeginForward for kernels is needed. */
bool NeedKernelBeginForward() {
std::vector<AttnBackendFunc*> funcs = {f_attention_prefill_.get(),
f_attention_prefill_ragged_.get(),
f_attention_decode_.get(),
f_attention_prefill_sliding_window_.get(),
f_attention_decode_sliding_window_.get(),
f_attention_prefill_with_tree_mask_.get(),
f_attention_prefill_with_tree_mask_paged_kv_.get(),
f_mla_prefill_.get()};
for (AttnBackendFunc* func : funcs) {
if (func != nullptr && func->backend_kind == AttnBackendKind::kFlashInfer) {
return true;
}
}
return false;
}
/*! \brief Invoke the "begin forward" functions of underlying kernels. */
void KernelBeginForward() {
if (!NeedKernelBeginForward()) {
return;
}
auto it_layer_begin = attn_kinds_.begin() + layer_id_begin_offset_;
auto it_layer_end = attn_kinds_.begin() + layer_id_end_offset_;
if (std::find(it_layer_begin, it_layer_end, AttnKind::kMHA) != it_layer_end) {
MHAKernelBeginForward();
}
if (std::find(it_layer_begin, it_layer_end, AttnKind::kMLA) != it_layer_end) {
MLAKernelBeginForward();
}
}
/*! \brief KernelBeginForward for multi-head attention. */
void MHAKernelBeginForward() {
if (!append_before_attn_) {
if (is_chain_on_depths_[0] && f_attention_prefill_ragged_ != nullptr &&
f_attention_prefill_ragged_->backend_kind == AttnBackendKind::kFlashInfer) {
f_attention_prefill_ragged_->BeginForward(
temp_float_attn_workspace_, temp_int_attn_workspace_[0],
temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_,
&cur_append_lengths_indptr_host_, cur_batch_size_,
cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_kv_heads_, qk_head_dim_,
v_head_dim_, /*causal=*/true, copy_stream_);
}
}
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
continue;
}
TVM_FFI_ICHECK(!support_sliding_window_ || !support_layer_sliding_window_)
<< "Kernel BeginForward doesn't support sliding window.";
if (use_decode_kernel_[d]) {
if (f_attention_decode_ != nullptr &&
f_attention_decode_->backend_kind == AttnBackendKind::kFlashInfer) {
f_attention_decode_->BeginForward(
d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
temp_int_pinned_attn_workspace_[d + 1], &page_indptr_on_depths_host_[d],
cur_batch_size_, page_size_, num_qo_heads_, num_kv_heads_, qk_head_dim_, v_head_dim_,
rope_mode_, kv_dtype_, kv_dtype_, copy_stream_);
}
} else {
if (f_attention_prefill_ != nullptr &&
f_attention_prefill_->backend_kind == AttnBackendKind::kFlashInfer) {
f_attention_prefill_->BeginForward(
d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
temp_int_pinned_attn_workspace_[d + 1], &qo_indptr_on_depths_host_[d],
&page_indptr_on_depths_host_[d], &last_page_len_on_depths_host_[d],
static_cast<int64_t>(qo_indptr_on_depths_host_[d].size()) - 1,
cur_append_lengths_indptr_host_.back(), page_size_, num_qo_heads_, num_kv_heads_,
qk_head_dim_, v_head_dim_, /*causal=*/false, copy_stream_);
}
}
}
}
/*! \brief KernelBeginForward for multi-head latent attention. */
void MLAKernelBeginForward() {
if (!append_before_attn_) {
if (is_chain_on_depths_[0]) {
if (f_attention_prefill_ragged_ != nullptr &&
f_attention_prefill_ragged_->backend_kind == AttnBackendKind::kFlashInfer) {
f_attention_prefill_ragged_->BeginForward(
temp_float_attn_workspace_, temp_int_attn_workspace_[0],
temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_,
&cur_append_lengths_indptr_host_, cur_batch_size_,
cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_qo_heads_, qk_head_dim_,
v_head_dim_, /*causal=*/true, copy_stream_);
}
}
}
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
continue;
}
TVM_FFI_ICHECK(!support_sliding_window_)
<< "Kernel BeginForward doesn't support sliding window.";
if (f_mla_prefill_ != nullptr &&
f_mla_prefill_->backend_kind == AttnBackendKind::kFlashInfer) {
f_mla_prefill_->BeginForward(
d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
temp_int_pinned_attn_workspace_[d + 1], &qo_indptr_on_depths_host_[d],
&page_indptr_on_depths_host_[d], &last_page_len_on_depths_host_[d],
static_cast<int64_t>(qo_indptr_on_depths_host_[d].size()) - 1,
cur_append_lengths_indptr_host_.back(), page_size_, num_qo_heads_, num_kv_heads_,
qk_head_dim_, v_head_dim_, /*causal=*/false, copy_stream_);
}
}
}
/*!
* \brief Compute attention between the input q data and both the input
* k/v data and the k/v data in the cache on the given layer.
*/
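// When "append_before_attn_" is false, self-attention over the newly appended
// (ragged) k/v runs first, then cross-attention against the paged cache is
// computed per depth and merged into the output through the LSE values. When it
// is true, the new k/v are already in the cache and the cross-attention pass
// alone covers both the history and the new tokens.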
void AttentionInternal(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data,
Tensor output, double sm_scale) {
int64_t local_layer_id = layer_id - layer_id_begin_offset_;
TVM_FFI_ICHECK_GE(local_layer_id, 0);
TVM_FFI_ICHECK_LT(local_layer_id, num_layers_);
bool is_first_kernel = true;
if (!append_before_attn_) {
// The first part of attention, which only involves the q and the newly appended k/v.
is_first_kernel = false;
MHASelfAttnInternal(q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale);
}
bool self_attn_computed = !is_first_kernel;
bool cross_attn_computed = MHACrossAttnInternal(
local_layer_id, q_data, output, merged_attn_lse_view_, sm_scale, is_first_kernel);
TVM_FFI_ICHECK(self_attn_computed || cross_attn_computed)
<< "Both self-attention and cross-attention are not computed.";
}
void MHASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
Tensor lse_data, double sm_scale) {
if (is_chain_on_depths_[0]) {
// If the batch does not form a tree, use raggedness prefill kernel.
TVM_FFI_ICHECK_NOTNULL(f_attention_prefill_ragged_);
f_attention_prefill_ragged_->MHA(
q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, rope_mode_,
rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
} else {
// The batch requires tree attention.
TVM_FFI_ICHECK(f_attention_prefill_with_tree_mask_ != nullptr)
<< "Function \"f_attention_prefill_with_tree_mask_\" is not defined.";
TVM_FFI_ICHECK(tree_attn_mask_view_[0].defined());
TVM_FFI_ICHECK(tree_attn_mn_indptr_view_[0].defined());
f_attention_prefill_with_tree_mask_->MHA(
q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
q_rope_position_map_view_, tree_attn_mn_indptr_view_[0], tree_attn_mask_view_[0],
rope_mode_, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
}
}
void MLASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data,
Tensor lse_data, double sm_scale) {
TVM_FFI_ICHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now.";
// If the batch does not form a tree, use raggedness prefill kernel.
TVM_FFI_ICHECK_NOTNULL(f_attention_prefill_ragged_);
f_attention_prefill_ragged_->MHA(
q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_,
q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, RoPEMode::kNone,
rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_);
}
/*! \brief Compute cross-attention for MHA. Return whether any effective computation happened. */
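// For each depth, the kernel is chosen from the batch shape: the tree-mask
// paged kernel when KV was appended before attention and the depth holds a
// non-chain token tree, the decode kernel when the depth qualifies for
// decoding, and the paged prefill kernel otherwise; sliding-window layers use
// the sliding-window kernel variants and their own page index/length views.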
bool MHACrossAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor o_data, Tensor lse_data,
double sm_scale, bool is_first_kernel) {
std::unique_ptr<PagedPrefillFunc>& f_prefill =
(!support_sliding_window_ &&
attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding)
? f_attention_prefill_
: f_attention_prefill_sliding_window_;
std::unique_ptr<PagedDecodeFunc>& f_decode =
(!support_sliding_window_ &&
attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding)
? f_attention_decode_
: f_attention_decode_sliding_window_;
TVM_FFI_ICHECK_GE(num_depths_, 1)
<< "The number of effective depths must be greater or equal to 1.";
bool cross_attn_computed = false;
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
continue;
}
Tensor attn_output;
Tensor attn_lse;
if (is_first_kernel) {
attn_output = o_data;
attn_lse = lse_data;
} else {
attn_output = temp_attn_output_view_;
attn_lse = temp_attn_lse_view_;
}
// If layer is sliding window, use sliding window index pointer/indices
Tensor page_indptr;
Tensor page_indices;
Tensor length_info;
Tensor k_rope_pos;
double rotary_theta;
double rotary_scale;
if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) {
page_indptr = page_indptr_sliding_window_on_depths_view_[d];
page_indices = page_indices_sliding_window_on_depths_view_[d];
length_info = layer_sliding_window_length_info_on_depths_view_[d];
k_rope_pos = k_rope_pos_offset_sliding_window_view_[d];
rotary_theta = 10000;
rotary_scale = 1;
} else {
page_indptr = page_indptr_on_depths_view_[d];
page_indices = page_indices_on_depths_view_[d];
length_info = length_info_on_depths_view_[d];
k_rope_pos = k_rope_pos_offset_view_[d];
rotary_theta = rotary_theta_;
rotary_scale = rotary_scale_;
}
if (append_before_attn_ && !is_chain_on_depths_[d]) {
TVM_FFI_ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_);
f_attention_prefill_with_tree_mask_paged_kv_->MHA(
q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices,
length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d],
tree_attn_mask_view_[d], rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output,
attn_lse, compute_stream_);
} else if (use_decode_kernel_[d]) {
// Use decode kernel for depth d
TVM_FFI_ICHECK_NOTNULL(f_decode);
f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info,
k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta,
sm_scale, attn_output, attn_lse, compute_stream_);
} else {
// Use prefill kernel for depth d
TVM_FFI_ICHECK_NOTNULL(f_prefill);
f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr,
page_indices, length_info, q_rope_position_map_view_, k_rope_pos,
/*causal=*/false,
/*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale,
attn_output, attn_lse, compute_stream_);
}
if (!is_first_kernel) {
f_merge_inplace_[0](o_data, lse_data, temp_attn_output_view_, temp_attn_lse_view_);
} else {
is_first_kernel = false;
}
cross_attn_computed = true;
}
return cross_attn_computed;
}
/*! \brief Compute cross-attention for MLA. Return whether any effective computation happened. */
bool MLACrossAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor o_data, Tensor lse_data,
double sm_scale) {
TVM_FFI_ICHECK_GE(num_depths_, 1)
<< "The number of effective depths must be greater or equal to 1.";
bool is_first_kernel = true;
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
continue;
}
Tensor attn_output;
Tensor attn_lse;
if (is_first_kernel) {
attn_output = o_data;
attn_lse = lse_data;
} else {
attn_output = temp_attn_output_view_;
attn_lse = temp_attn_lse_view_;
}
TVM_FFI_ICHECK(is_chain_on_depths_[d]) << "Tree attn not able for MLA for now.";
TVM_FFI_ICHECK_NOTNULL(f_mla_prefill_);
f_mla_prefill_->MLA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
length_info_on_depths_view_[d], /*causal=*/false, sm_scale, attn_output,
attn_lse, compute_stream_);
if (!is_first_kernel) {
f_merge_inplace_[0](o_data, lse_data, temp_attn_output_view_, temp_attn_lse_view_);
} else {
is_first_kernel = false;
}
}
return !is_first_kernel;
}
/*! \brief Synchronize the copy stream and the compute stream. */
void ComputeStreamWaitForCopyStream() {
if (!dirty_aux_data_device_) {
// If the auxiliary data is already synced, return and no need to sync again.
return;
}
// - Sync Tensors to GPU.
SyncAuxArrayToDevice();
KernelBeginForward();
// - Clear the dirty flag.
dirty_aux_data_device_ = false;
// - If there is no particular copy stream, no action is needed.
if (copy_stream_ == nullptr) {
return;
}
// - Sync two streams.
DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, compute_stream_);
}
/*!
* \brief Synchronize auxiliary arrays to device.
* \note This method resets the dirty flag to false, and needs to be
* invoked before running attention computation on device.
*/
void SyncAuxArrayToDevice() {
TVM_FFI_ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt);
int64_t total_append_length = 0;
int num_sequences = cur_append_lengths_.size();
cur_append_lengths_indptr_host_.clear();
cur_append_lengths_indptr_host_.push_back(0);
for (int i = 0; i < num_sequences; ++i) {
cur_append_lengths_indptr_host_.push_back(cur_append_lengths_indptr_host_.back() +
cur_append_lengths_[i]);
}
total_append_length = cur_append_lengths_indptr_host_.back();
TVM_FFI_ICHECK_EQ(total_append_length, append_position_map_host_.size());
TVM_FFI_ICHECK_EQ(total_append_length, kv_transfer_remote_position_map_host_.size());
TVM_FFI_ICHECK_EQ(total_append_length, kv_transfer_recver_id_host_.size());
// - Reset the copy.
aux_data_manager_->ResetAttnAuxDataCopy();
// 1. q_rope_position_map
// q_rope_position_map has to be synced first so that it has a 0 byte offset
TVM_FFI_ICHECK_EQ(q_rope_position_map_host_.size(), total_append_length);
q_rope_position_map_view_ = aux_data_manager_->CopyQRoPEPosMapAsync(&q_rope_position_map_host_);
// 2. qo_indptr_on_depths
for (int d = 0; d < num_depths_; ++d) {
qo_indptr_on_depths_view_[d] =
aux_data_manager_->CopyQOIndptrOnDepthAsync(&qo_indptr_on_depths_host_[d], d);
}
// 3. page_indptr_on_depths
for (int d = 0; d < num_depths_; ++d) {
TVM_FFI_ICHECK_EQ(page_indptr_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size());
page_indptr_on_depths_view_[d] =
aux_data_manager_->CopyPageIndptrOnDepthAsync(&page_indptr_on_depths_host_[d], d);
}
// 4. page_indices_on_depths
for (int d = 0; d < num_depths_; ++d) {
TVM_FFI_ICHECK_EQ(page_indices_on_depths_host_[d].size(),
page_indptr_on_depths_host_[d].back());
page_indices_on_depths_view_[d] =
aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_on_depths_host_[d], d);
}
// If per layer sliding window exists, must copy additional vectors
if (support_layer_sliding_window_) {
// 5. page_indptr_sliding_window_on_depths
for (int d = 0; d < num_depths_; ++d) {
TVM_FFI_ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(),
qo_indptr_on_depths_host_[d].size());
page_indptr_sliding_window_on_depths_view_[d] =
aux_data_manager_->CopyPageIndptrOnDepthAsync(
&page_indptr_sliding_window_on_depths_host_[d], d);
}
// 6. page_indices_sliding_window_on_depths
for (int d = 0; d < num_depths_; ++d) {
TVM_FFI_ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(),
page_indptr_sliding_window_on_depths_host_[d].back());
page_indices_sliding_window_on_depths_view_[d] =
aux_data_manager_->CopyPageIndicesOnDepthAsync(
&page_indices_sliding_window_on_depths_host_[d], d);
}
}
// 7. length_info_on_depths
// last_page_len_on_depths_host_;
// sliding_window_offset_on_depths_host_;
// sink_size_on_depths_host_;
for (int d = 0; d < num_depths_; ++d) {
int num_seq_on_layer = static_cast<int>(qo_indptr_on_depths_host_[d].size()) - 1;
TVM_FFI_ICHECK_EQ(last_page_len_on_depths_host_[d].size(), num_seq_on_layer);
TVM_FFI_ICHECK_EQ(sliding_window_offset_on_depths_host_[d].size(), num_seq_on_layer);
TVM_FFI_ICHECK_EQ(sink_size_on_depths_host_[d].size(), num_seq_on_layer);
if (!support_sliding_window_) {
// Sliding window is not enabled, so we first copy "last_page_len".
length_info_on_depths_view_[d] =
aux_data_manager_->CopyLastPageLenOnDepthAsync(&last_page_len_on_depths_host_[d], d);
} else {
// Sliding window is enabled, so we copy the full length info (last page
// length, sliding window offset, and sink size).
length_info_on_depths_view_[d] = aux_data_manager_->CopyLengthInfoOnDepthAsync(
&last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d],
&sink_size_on_depths_host_[d], d);
}
if (support_layer_sliding_window_) {
layer_sliding_window_length_info_on_depths_view_[d] =
aux_data_manager_->CopyLengthInfoOnDepthAsync(&last_page_len_on_depths_host_[d],
&sliding_window_offset_on_depths_host_[d],
&sink_size_on_depths_host_[d], d);
}
}
// 8. k_rope_pos_offset_on_depths
for (int d = 0; d < num_depths_; ++d) {
TVM_FFI_ICHECK_EQ(k_rope_pos_offset_on_depths_host_[d].size() + 1,
qo_indptr_on_depths_host_[d].size());
k_rope_pos_offset_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync(
&k_rope_pos_offset_on_depths_host_[d], d);
if (support_layer_sliding_window_) {
TVM_FFI_ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1,
qo_indptr_on_depths_host_[d].size());
k_rope_pos_offset_sliding_window_view_[d] =
aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync(
&k_rope_pos_offset_sliding_window_on_depths_host_[d], d);
}
}
// 9. cur_append_lengths_indptr
cur_append_length_indptr_view_ =
aux_data_manager_->CopyCurAppendLengthIndptrAsync(&cur_append_lengths_indptr_host_);
// 10. k_ragged_rope_pos_offset
TVM_FFI_ICHECK_EQ(k_ragged_rope_pos_offset_host_.size(), num_sequences);
k_ragged_rope_pos_offset_view_ =
aux_data_manager_->CopyKRaggedRoPEPosOffsetAsync(&k_ragged_rope_pos_offset_host_);
// 11. append_position_map
append_position_map_view_ =
aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
// 12. kv_transfer_remote_position_map
kv_transfer_remote_position_map_view_ = aux_data_manager_->CopyKVTransferRemotePositionMapAsync(
&kv_transfer_remote_position_map_host_);
// 13. kv_transfer_recver_id
kv_transfer_recver_id_view_ =
aux_data_manager_->CopyKVTransferRecverIDAsync(&kv_transfer_recver_id_host_);
// 14. kv_transfer_page_to_page_local_position_map
kv_transfer_page_to_page_local_position_map_view_ =
aux_data_manager_->CopyKVTransferPage2PageLocalPositionMapAsync(
&kv_transfer_page_to_page_local_position_map_host_);
// 15. kv_transfer_page_to_page_remote_position_map
kv_transfer_page_to_page_remote_position_map_view_ =
aux_data_manager_->CopyKVTransferPage2PageRemotePositionMapAsync(
&kv_transfer_page_to_page_remote_position_map_host_);
// 16. kv_transfer_page_to_page_recver_id
kv_transfer_page_to_page_recver_id_view_ =
aux_data_manager_->CopyKVTransferPage2PageRecverIDAsync(
&kv_transfer_page_to_page_recver_id_host_);
// 17. tree_attn_mask and tree_attn_mn_indptr
for (int d = 0; d < num_depths_; ++d) {
if (!is_chain_on_depths_[d]) {
tree_attn_mask_view_[d] =
aux_data_manager_->CopyTreeAttnMaskOnDepthAsync(&tree_attn_mask_host_[d], d);
tree_attn_mn_indptr_view_[d] =
aux_data_manager_->CopyTreeAttnMNIndptrOnDepthAsync(&tree_attn_mn_indptr_host_[d], d);
}
}
// 18. Create views of the temporary arrays for attention computation.
temp_attn_output_view_ = temp_attn_output_device_.CreateView(
{total_append_length, num_qo_heads_, v_head_dim_}, temp_attn_output_device_->dtype);
temp_attn_lse_view_ = temp_attn_lse_device_.CreateView({total_append_length, num_qo_heads_},
temp_attn_lse_device_->dtype);
merged_attn_lse_view_ = merged_attn_lse_device_.CreateView({total_append_length, num_qo_heads_},
merged_attn_lse_device_->dtype);
// - Commit the copy.
aux_data_manager_->CommitAttnAuxDataCopy();
// - Reset the dirty flag to false.
dirty_aux_data_device_ = false;
}
}; // class PagedAttentionKVCacheObj
//-------------------------------------------------
// Register runtime functions
//-------------------------------------------------
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"vm.builtin.paged_attention_kv_cache_create", [](ffi::PackedArgs args, ffi::Any* rv) {
// Todo: cuda graph arg
TVM_FFI_ICHECK(args.size() == 28 || args.size() == 29)
<< "Invalid number of KV cache constructor args: " << args.size();
ffi::Shape cache_config = args[0].cast<ffi::Shape>();
ffi::Shape layer_indptr_tuple = args[1].cast<ffi::Shape>();
int num_groups = 1;
int group_id = 0;
if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
// In the Disco worker thread
num_groups = disco_worker->num_groups;
group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups);
}
TVM_FFI_ICHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id];
int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
int64_t layer_id_end_offset = layer_indptr_tuple[group_id + 1];
int64_t num_qo_heads = args[2].cast<int64_t>();
int64_t num_kv_heads = args[3].cast<int64_t>();
int64_t qk_head_dim = args[4].cast<int64_t>();
int64_t v_head_dim = args[5].cast<int64_t>();
ffi::Shape attn_kinds = args[6].cast<ffi::Shape>();
bool enable_kv_transfer = args[7].cast<bool>();
int rope_mode = args[8].cast<int>();
double rotary_scale = args[9].cast<double>();
double rotary_theta = args[10].cast<double>();
ffi::Optional<Tensor> rope_ext_factors = std::nullopt; // args[11]
Tensor init = args[12].cast<Tensor>();
ffi::Optional<ffi::Function> f_transpose_append_mha = std::nullopt; // args[13]
ffi::Optional<ffi::Function> f_transpose_append_mla = std::nullopt; // args[14]
std::unique_ptr<RaggedPrefillFunc> f_attention_prefill_ragged =
ConvertRaggedPrefillFunc(args[15].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<PagedPrefillFunc> f_attention_prefill =
ConvertPagedPrefillFunc(args[16].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<PagedDecodeFunc> f_attention_decode =
ConvertPagedDecodeFunc(args[17].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<PagedPrefillFunc> f_attention_prefill_sliding_window =
ConvertPagedPrefillFunc(args[18].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<PagedDecodeFunc> f_attention_decode_sliding_window =
ConvertPagedDecodeFunc(args[19].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<PagedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask_paged_kv =
ConvertPagedPrefillTreeMaskFunc(args[20].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<RaggedPrefillTreeMaskFunc> f_attention_prefill_with_tree_mask =
ConvertRaggedPrefillTreeMaskFunc(args[21].cast<ffi::Array<ffi::Any>>(), AttnKind::kMHA);
std::unique_ptr<PagedPrefillFunc> f_mla_prefill =
ConvertPagedPrefillFunc(args[22].cast<ffi::Array<ffi::Any>>(), AttnKind::kMLA);
ffi::Array<ffi::Function> f_merge_inplace = args[23].cast<ffi::Array<ffi::Function>>();
ffi::Function f_split_rotary = args[24].cast<ffi::Function>();
ffi::Function f_copy_single_page = args[25].cast<ffi::Function>();
ffi::Function f_debug_get_kv = args[26].cast<ffi::Function>();
ffi::Function f_compact_copy = args[27].cast<ffi::Function>();
if (auto opt_nd = args[11].as<Tensor>()) {
rope_ext_factors = opt_nd.value();
}
auto f_convert_optional_packed_func = [&args](int arg_idx) -> ffi::Optional<ffi::Function> {
if (auto opt_func = args[arg_idx].as<ffi::Function>()) {
return opt_func.value();
}
return std::nullopt;
};
f_transpose_append_mha = f_convert_optional_packed_func(13);
f_transpose_append_mla = f_convert_optional_packed_func(14);
TVM_FFI_ICHECK(!f_merge_inplace.empty()) << "Merge inplace function is not defined.";
std::vector<AttnKind> attn_kinds_vec;
attn_kinds_vec.reserve(attn_kinds.size());
for (int64_t attn_kind : attn_kinds) {
attn_kinds_vec.push_back(static_cast<AttnKind>(attn_kind));
}
TVM_FFI_ICHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
int64_t prefill_chunk_size = cache_config[2];
int64_t page_size = cache_config[3];
bool support_sliding_window = cache_config[4];
int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1;
if (support_sliding_window) {
// When sliding window is enabled, each sequence may use two more pages at most.
num_total_pages += reserved_num_seqs * 2;
}
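// For example (hypothetical numbers): total_token_capacity = 16384 and
// page_size = 16 give num_total_pages = 16384 / 16 + 1 = 1025; with
// support_sliding_window and reserved_num_seqs = 32, 64 extra pages are added.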
// NOTE: We will remove this legacy construction after finishing the transition phase.
// Some `ffi::Function()` here are placeholders that will be filled.
ObjectPtr<PagedAttentionKVCacheObj> n = ffi::make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads,
num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs,
num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
rotary_scale, rotary_theta, std::move(rope_ext_factors), enable_kv_transfer, //
init->dtype, init->device, //
std::move(f_transpose_append_mha), std::move(f_transpose_append_mla),
std::move(f_compact_copy), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window),
std::move(f_attention_prefill_with_tree_mask_paged_kv), //
std::move(f_attention_prefill_with_tree_mask), //
std::move(f_mla_prefill), std::move(f_merge_inplace), std::move(f_split_rotary),
std::move(f_copy_single_page), std::move(f_debug_get_kv));
*rv = AttentionKVCache(std::move(n));
});
}
} // namespace vm
} // namespace runtime
} // namespace tvm