blob: afb962e4fc6f9db58e5819738a991d534241f137 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/vm/attn_utils.h
* \brief Data structure and utilities for KV cache.
*/
#ifndef TVM_RUNTIME_VM_ATTN_UTILS_H_
#define TVM_RUNTIME_VM_ATTN_UTILS_H_
#include <tvm/runtime/tensor.h>
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>
#if defined(OPENCL_ENABLE_HOST_PTR)
#include "../opencl/opencl_common.h"
#endif
namespace tvm {
namespace runtime {
namespace vm {
/*!
* \brief The maximum allowed block depth (a.k.a. number of common
* prefixes) in paged KV cache.
*/
constexpr const int kPagedKVCacheMaxBlockDepth = 2;
/*! \brief The maximum tree size of a single sequence in tree attention. */
constexpr const int kTreeAttnMaxTreeSize = 256;
/*! \brief The 1MB workspace size for integer attention auxiliary data. */
constexpr const int kIntAttnWorkspaceByte = 8 * 1024 * 1024;
/*! \brief The 128MB workspace size for floating-point attention auxiliary data. */
constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024;
/*! \brief The id of the temporary logical page, which is useful for sliding window. */
constexpr const int kPagedKVCacheTempPageId = -1;
/*!
* \brief The supported attention kinds in PagedKVCache.
* "MHA" means multi-head attention, multi-query attention and grouped query attention in general.
* "MLA" means multi-head latent attention.
* "LinearAttn" means linear attention.
*/
enum class AttnKind : int {
kMHA = 0,
kMLA = 1,
kLinearAttn = 2,
kMHASliding = 3,
};
/*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */
inline ffi::Shape GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence,
int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim,
int64_t v_head_dim) {
if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
// Ignore v_head_dim since multi-head attention requires K/V to have the same head dim.
return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim};
} else if (attn_kind == AttnKind::kMLA) {
return {num_total_pages, page_size, qk_head_dim};
} else if (attn_kind == AttnKind::kLinearAttn) {
return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim};
}
TVM_FFI_ICHECK(false);
return ffi::Shape();
}
/*!
* \brief The block structure in paged KV cache with common prefix support.
* Each block contains a list of pages for cached KV data.
* If a block has `n` pages, the first `n - 1` pages must be
* full, and only the last page can be partially filled.
*
* To support common prefix, each sequence in KV cache is represented
* as one or more blocks, where the common prefix is a standalone
* block among.
*
* Each block has a parent block when it uses a prefix.
*/
struct Block {
/*!
* \brief The ids of the pages in the block.
* Each page can only be used by a unique block (in other
* words, different blocks do not share pages).
*/
std::vector<int32_t> page_ids;
/*! \brief The total sequence length in the block. */
int32_t seq_length = 0;
/*!
* \brief The start position in sequence of this block.
* This is the absolute position in the sequence for RoPE computation.
*/
int32_t start_pos = 0;
/*!
* \brief The current attention sink length of the block.
* It means the **first** sink size elements will be pinned
* in the KV cache even when sliding window is enabled.
*/
int32_t sink_length = 0;
/*!
* \brief The start offset of the sliding window in the block.
* It is always 0 when sliding window attn is not enabled.
*/
int32_t sliding_window_offset = 0;
/*! \brief The global index of the block. */
const int32_t index;
/*!
* \brief The global index of the parent block of this block, or -1
* if the block does not have a parent. */
int32_t parent_idx = -1;
/*!
* \brief The external reference counter of the block.
* When a block is externally referred by some block,
* we do not allow appending new KV values to this block.
*/
int external_ref_cnt = 0;
explicit Block(int32_t index) : index(index) {}
/*! \brief Reset the block data. */
void Reset() {
page_ids.clear();
seq_length = 0;
start_pos = 0;
sink_length = 0;
sliding_window_offset = 0;
parent_idx = -1;
external_ref_cnt = 0;
}
};
struct KVTransferMetadata {
int64_t start = std::numeric_limits<int64_t>::max();
std::vector<int64_t> remote_position_map;
int32_t recver_pe_offset = -1;
std::vector<int64_t> local_position_map;
};
/*!
* \brief The sequence structure in paged KV cache with common prefix support.
* Each sequence contains one or more blocks to support common prefix.
*/
struct Sequence {
/*!
* \brief The global index of the last block of the sequence.
* We only store the last block, since all the blocks can be
* tracked with the `parent` field of Block.
*/
int32_t last_block_idx;
/*!
* \brief The total sequence length of the sequence.
* It is the sum of lengths of all its blocks.
*/
int32_t seq_length = 0;
/*!
* \brief The sliding window size of the sequence, or -1 if sliding window is not enabled.
* When a sequence is enabled for sliding window, it can no longer be forked.
*/
int sliding_window_size = -1;
/*!
* \brief The attention sink size of the last block of the sequence.
* The **first** sink size elements of the last block will be pinned
* in the KV cache even when sliding window is enabled.
*/
int last_block_attn_sink_size = 0;
/*! \brief Whether the current appended tokens form a chain (not a tree). */
bool is_chain = true;
/*! \brief The token tree parent pointer array of the current appended tokens. */
std::vector<int32_t> token_tree_parent_ptr;
/*! \brief The depth of each node in the token tree. */
std::vector<int32_t> token_tree_node_depths;
/*! \brief The metadata of kv transfer*/
KVTransferMetadata kv_transfer_metadata;
/*!
* \brief A boolean denoting whether the accepted token tree indices of
* this sequence are committed
*/
bool accepted_indices_committed = true;
explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) {
++global_block_pool->at(last_block_idx).external_ref_cnt;
this->last_block_idx = last_block_idx;
int32_t block_ptr = last_block_idx;
// Go through each block in the sequence, sum up the length.
while (true) {
const Block& block = global_block_pool->at(block_ptr);
this->seq_length += block.seq_length;
if (block.parent_idx == -1) {
break;
}
block_ptr = block.parent_idx;
}
}
std::vector<int32_t> GetBlockTrace(const std::vector<Block>& global_block_pool) const {
std::vector<int32_t> trace;
// Get the trace from the last block of the sequence to the root block.
int32_t block_ptr = last_block_idx;
while (block_ptr != -1) {
trace.push_back(block_ptr);
block_ptr = global_block_pool[block_ptr].parent_idx;
}
// Reverse the trace so that it starts from the root block.
std::reverse(trace.begin(), trace.end());
return trace;
}
};
/*!
* \brief For the given list of sequences, check the block trace of
* each sequence, and return the blocks ids used by the sequences
* on each depth. And if the depth is larger than the kPagedKVCacheMaxBlockDepth,
* the exceeding blocks will concatenate and output separately.
* More precisely, the inner returned vector contains the block ids
* used by the sequences on a certain depth (or "-1" if a sequence
* has fewer depth). The outer returned vector contains the inner
* vectors from the lowest depth to the highest depth.
*/
inline std::pair<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
GetBlockIdsOnDepth(const std::vector<Sequence*>& sequences,
const std::vector<Block>& global_block_pool, int64_t batch_size) {
// - Get the trace of each sequence.
int64_t num_depths = 0;
std::vector<std::vector<int32_t>> seq_block_traces;
std::vector<std::vector<int32_t>> trailing_block_traces;
seq_block_traces.reserve(batch_size);
trailing_block_traces.reserve(batch_size);
for (int i = 0; i < batch_size; ++i) {
std::vector<int32_t> trace = sequences[i]->GetBlockTrace(global_block_pool);
if (static_cast<int>(trace.size()) <= kPagedKVCacheMaxBlockDepth) {
seq_block_traces.push_back(std::vector<int32_t>(trace.begin(), trace.end()));
trailing_block_traces.push_back({});
num_depths = std::max(num_depths, static_cast<int64_t>(trace.size()));
} else {
seq_block_traces.push_back(
std::vector<int32_t>(trace.begin(), trace.begin() + kPagedKVCacheMaxBlockDepth));
trailing_block_traces.push_back(
std::vector<int32_t>(trace.begin() + kPagedKVCacheMaxBlockDepth, trace.end()));
num_depths = std::max(num_depths, static_cast<int64_t>(kPagedKVCacheMaxBlockDepth));
}
}
// "Transpose" the traces, yielding the block ids used on each depth.
std::vector<std::vector<int32_t>> block_ids_on_depths;
block_ids_on_depths.reserve(num_depths);
for (int d = 0; d < num_depths; ++d) {
std::vector<int32_t> block_ids;
block_ids.reserve(batch_size);
for (int i = 0; i < batch_size; ++i) {
block_ids.push_back(d < static_cast<int>(seq_block_traces[i].size()) ? seq_block_traces[i][d]
: -1);
}
block_ids_on_depths.push_back(std::move(block_ids));
}
return {block_ids_on_depths, trailing_block_traces};
}
/*!
* \brief This function considers an optimization which coalesces
* adjacent decode attention computations into a single prefill
* attention computation if the adjacent decodes attend to the same
* k/v values under certain conditions.
* If it decides to coalesce on a certain depth, we need to know
* the prefill length after coalescing. This function returns
* - a vector of block ids together with the prefill/decode lengths
* that attend to the blocks.
* - a boolean indicating whether to use decode kernel on for the
* input blocks.
*/
inline std::pair<std::vector<std::pair<int32_t, int32_t>>, bool> GetChunkedBlockIds(
const std::vector<int32_t>& block_ids, bool enable_coalesce, const IntTuple& append_lengths,
const std::vector<Block>& global_block_pool, bool is_decode_request) {
std::vector<std::pair<int32_t, int32_t>> uncoalesced_block_ids;
std::vector<std::pair<int32_t, int32_t>> coalesced_block_ids;
// Gather the number of pages before/after coalescing respectively.
int cur_block_id = block_ids[0];
int chunk_append_length = append_lengths[0];
int page_counter_coalesced = 0;
int page_counter_uncoalesced =
block_ids[0] != -1 ? global_block_pool[block_ids[0]].page_ids.size() : 0;
for (int i = 1; i < static_cast<int>(block_ids.size()); ++i) {
if (block_ids[i] != -1) {
page_counter_uncoalesced += global_block_pool[block_ids[i]].page_ids.size();
}
uncoalesced_block_ids.emplace_back(block_ids[i - 1], append_lengths[i - 1]);
if (block_ids[i] == cur_block_id) {
chunk_append_length += append_lengths[i];
} else {
coalesced_block_ids.emplace_back(cur_block_id, chunk_append_length);
if (cur_block_id != -1) {
page_counter_coalesced += global_block_pool[cur_block_id].page_ids.size();
}
cur_block_id = block_ids[i];
chunk_append_length = append_lengths[i];
}
}
uncoalesced_block_ids.emplace_back(block_ids.back(), append_lengths.back());
coalesced_block_ids.emplace_back(cur_block_id, chunk_append_length);
if (cur_block_id != -1) {
page_counter_coalesced += global_block_pool[cur_block_id].page_ids.size();
}
double coalesce_ratio =
page_counter_coalesced > 0 ? 1.0 * page_counter_uncoalesced / page_counter_coalesced : 0.0;
// Do not coalesce and use batch decode kernel when coalesce ratio is small.
bool use_decode_kernel = is_decode_request && coalesce_ratio < 32;
return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids,
use_decode_kernel};
}
/*!
* \brief The rotary embedding mode adopted by the paged KV cache
* when computing attention.
* "None" means RoPE is never applied to q and k.
* "Normal" means RoPE is computed in a standalone kernel.
* "Inline" means RoPE is computed on-the-fly in attention kernels.
*/
enum class RoPEMode : int {
kNone = 0,
kNormal = 1,
kInline = 2,
};
/*!
* \brief The class of host memory int32 vector in "std::vector" interface.
* This vector allocates static memory on the specified host memory
* at the time of construction.
*/
class HostMemoryVector {
public:
HostMemoryVector() = default;
HostMemoryVector(const HostMemoryVector&) = delete;
HostMemoryVector(HostMemoryVector&& other) = default;
HostMemoryVector& operator=(const HostMemoryVector&) = delete;
HostMemoryVector& operator=(HostMemoryVector&& other) = default;
explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device)
: reserved_size_(reserved_size) {
TVM_FFI_ICHECK(DataType(dtype) == DataType::Int(32));
data_ = Tensor::Empty({reserved_size}, dtype, device);
}
void push_back(int32_t value) {
TVM_FFI_ICHECK_LE(current_size_, reserved_size_);
if (current_size_ == reserved_size_) {
reserved_size_ *= 2;
Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device);
std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes());
data_ = new_data;
}
static_cast<int32_t*>(data_->data)[current_size_++] = value;
}
const int32_t& operator[](int64_t idx) const {
TVM_FFI_ICHECK_GE(idx, 0) << "Index " << idx << " is negative.";
TVM_FFI_ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_;
return static_cast<int32_t*>(data_->data)[idx];
}
int32_t back() const {
TVM_FFI_ICHECK_GT(current_size_, 0) << "Vector is empty";
return static_cast<int32_t*>(data_->data)[current_size_ - 1];
}
size_t size() const { return static_cast<size_t>(current_size_); }
int32_t* data() const { return static_cast<int32_t*>(data_->data); }
void clear() { current_size_ = 0; }
/*! \brief Return the vector as an Tensor. */
Tensor as_tensor() { return data_.CreateView({current_size_}, data_->dtype); }
IntTuple as_int_tuple() const {
std::vector<int64_t> values;
values.reserve(current_size_);
for (int i = 0; i < current_size_; ++i) {
values.push_back(static_cast<int32_t*>(data_->data)[i]);
}
return IntTuple(values);
}
private:
int64_t reserved_size_ = 0;
int64_t current_size_ = 0;
Tensor data_{nullptr};
};
/*!
* \brief The paged attention auxiliary data manager class.
* This class manages all the int32 auxiliary data on GPU device, such as
* page table, position arrays, etc..
*
* The core functions of this class is `CopyXXXAsync` and `CommitAttnAuxDataCopy`.
* `CopyXXXAsync` takes the input data on CPU host, and copy the input data
* to GPU in an asynchronous way, and returns the Tensor view of the data
* on GPU device.
*
* Being asynchronous here means the `CopyXXXAsync` function may not perform
* data copy from CPU to GPU at the time of being called. Therefore, the
* returned Tensor view may have wrong result, until `CommitAttnAuxDataCopy` is
* explicitly invoked and the data copy stream is synchronized.
*
* We design this manager class in order to reduce the data copy overhead.
*/
class PagedKVCacheAuxDataManager {
public:
PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, Device preferred_host_device,
TVMStreamHandle copy_stream)
: dtype_aux_(dtype_aux),
device_(device),
preferred_host_device_(preferred_host_device),
copy_stream_(copy_stream) {
TVM_FFI_ICHECK(DataType(dtype_aux) == DataType::Int(32));
}
virtual ~PagedKVCacheAuxDataManager() = default;
/*! \brief Reset the attention auxiliary data status of copy manager. */
virtual void ResetAttnAuxDataCopy() = 0;
/*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */
virtual Tensor CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*! \brief Copy the indptr array of page table. */
virtual Tensor CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*! \brief Copy the indices array of page table. */
virtual Tensor CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*! \brief Copy the array of KV slot number used in the last page of the seq. */
virtual Tensor CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*!
* \brief Copy the length information of the sequences.
* Each Tensor is in shape `(3, n)`. "n" is the number of sequences.
* For a sequence "i", location
* - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
* - "(1, i)" is the starting offset of the sliding window in the seq,
* - "(2, i)" is the attn sink length of the sequence.
* \note When sliding window is not enabled, only the
* "last_page_len" (a.k.a., the first "n" elements) will be effectively used.
*/
virtual Tensor CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
HostMemoryVector* sliding_window_offset,
HostMemoryVector* sink_size, int depth) = 0;
/*! \brief Copy the k position offset of applying RoPE for each sequence. */
virtual Tensor CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*!
* \brief Copy the append length indptr array on device.
* \note Since the Q/K/V data may have raggedness in terms of lengths,
* we represent the append lengths in CSR format.
*/
virtual Tensor CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the k position offset of applying RoPE for each sequence. */
virtual Tensor CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the q position mapping of applying RoPE for each sequence. */
virtual Tensor CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0;
/*!
* \brief Copy the corresponding position in global KV cache (pages)
* for each position along the length dimension of K/V data when
* appending new K/V data.
*/
virtual Tensor CopyAppendPositionMapAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the remote position map for KV transfer. */
virtual Tensor CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the receiver id for KV transfer. */
virtual Tensor CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the local position map for KV page-to-page transfer. */
virtual Tensor CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the remote position map for KV page-to-page transfer. */
virtual Tensor CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the receiver id for KV page-to-page transfer. */
virtual Tensor CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the tree attention mask. */
virtual Tensor CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*! \brief Copy the mn indptr of the tree attention mask. */
virtual Tensor CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0;
/*! \brief Commit all the attention auxiliary data copy operations since the last commit. */
virtual void CommitAttnAuxDataCopy() = 0;
/*! \brief Reset the compact KV auxiliary data status of copy manager. */
virtual void ResetCompactKVAuxDataCopy() = 0;
/*! \brief Copy the length indptr array of KV data copy for each sequence. */
virtual Tensor CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0;
/*! \brief Copy the src/dst position arrays for each sequence. */
virtual Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
HostMemoryVector* dst_data) = 0;
/*! \brief Commit all the compact KV auxiliary data copy operations since the last commit. */
virtual void CommitCompactKVAuxDataCopy() = 0;
protected:
/*! \brief The dtype of the auxiliary data. It is expected to be int32. */
const DLDataType dtype_aux_;
/*! \brief The device this PagedKVCache runs on. */
const Device device_;
/*! \brief The preferred host device. */
const Device preferred_host_device_;
/*! \brief The device stream for copying auxiliary data structure to GPU. */
const TVMStreamHandle copy_stream_;
};
/*!
* \brief The plain auxiliary data manager class.
* It simply issues one host-to-device copy operation for each `CopyXXXAsync`.
*/
class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
public:
explicit PlainPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages,
int64_t prefill_chunk_size, DLDataType dtype_aux,
Device device, Device preferred_host_device,
TVMStreamHandle copy_stream)
: PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream) {
for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
qo_indptr_on_depths_device_.push_back(
Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
page_indptr_on_depths_device_.push_back(
Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
page_indices_on_depths_device_.push_back(
Tensor::Empty({num_total_pages}, dtype_aux_, device));
length_info_on_depths_device_.push_back(
Tensor::Empty({3, reserved_num_seqs}, dtype_aux_, device));
k_rope_pos_offset_on_depths_device_.push_back(
Tensor::Empty({reserved_num_seqs}, dtype_aux_, device));
tree_attn_mask_device_.push_back(Tensor::Empty(
{kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device));
tree_attn_mn_indptr_device_.push_back(
Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
}
cur_append_length_indptr_device_ = Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
k_ragged_rope_pos_offset_device_ = Tensor::Empty({reserved_num_seqs}, dtype_aux_, device);
q_rope_position_map_device_ = Tensor::Empty({prefill_chunk_size}, dtype_aux_, device);
append_position_map_device_ = Tensor::Empty({prefill_chunk_size}, dtype_aux_, device);
kv_transfer_remote_position_map_device =
Tensor::Empty({prefill_chunk_size}, dtype_aux_, device);
kv_transfer_recver_id_device = Tensor::Empty({prefill_chunk_size}, dtype_aux_, device);
kv_transfer_page_to_page_local_position_map_device =
kv_transfer_page_to_page_remote_position_map_device =
Tensor::Empty({prefill_chunk_size}, dtype_aux_, device);
kv_transfer_page_to_page_recver_id_device =
Tensor::Empty({prefill_chunk_size}, dtype_aux_, device);
commit_copy_length_indptr_device_ = Tensor::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
commit_copy_src_dst_pos_in_page_table_device_ =
Tensor::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)},
dtype_aux_, device);
}
// The reset of the plain auxiliary data manager is no-op.
void ResetAttnAuxDataCopy() final {}
Tensor CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view = qo_indptr_on_depths_device_[depth].CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view = page_indptr_on_depths_device_[depth].CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view = page_indices_on_depths_device_[depth].CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view = length_info_on_depths_device_[depth].CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view = k_rope_pos_offset_on_depths_device_[depth].CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final {
Tensor view = cur_append_length_indptr_device_.CreateView({static_cast<int64_t>(data->size())},
dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final {
Tensor view = k_ragged_rope_pos_offset_device_.CreateView({static_cast<int64_t>(data->size())},
dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyQRoPEPosMapAsync(HostMemoryVector* data) final {
Tensor view =
q_rope_position_map_device_.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyAppendPositionMapAsync(HostMemoryVector* data) final {
Tensor view =
append_position_map_device_.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
Tensor view = kv_transfer_remote_position_map_device.CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
Tensor view =
kv_transfer_recver_id_device.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final {
Tensor view = kv_transfer_page_to_page_local_position_map_device.CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final {
Tensor view = kv_transfer_page_to_page_remote_position_map_device.CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
Tensor view = kv_transfer_page_to_page_recver_id_device.CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view =
tree_attn_mask_device_[depth].CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor view = tree_attn_mn_indptr_device_[depth].CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
HostMemoryVector* sliding_window_offset,
HostMemoryVector* sink_size, int depth) final {
int n_elem = last_page_len->size();
TVM_FFI_ICHECK_GT(n_elem, 0);
Tensor view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_);
ffi::Shape copy_shape{n_elem};
CopyVecDataToArray(view, last_page_len->data(), copy_shape);
CopyVecDataToArray(view, sliding_window_offset->data(), copy_shape,
/*dst_elem_offset=*/n_elem);
CopyVecDataToArray(view, sink_size->data(), copy_shape,
/*dst_elem_offset=*/2 * n_elem);
return view;
}
// The commit of the plain auxiliary data manager is no-op.
void CommitAttnAuxDataCopy() final {}
// The reset of the plain auxiliary data manager is no-op.
void ResetCompactKVAuxDataCopy() final {}
Tensor CopyCommitLengthIndptrAsync(HostMemoryVector* data) final {
Tensor view = commit_copy_length_indptr_device_.CreateView({static_cast<int64_t>(data->size())},
dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
HostMemoryVector* dst_data) final {
int n_elem = src_data->size();
TVM_FFI_ICHECK_GT(n_elem, 0);
Tensor view = commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_);
ffi::Shape copy_shape{n_elem};
CopyVecDataToArray(view, src_data->data(), copy_shape);
CopyVecDataToArray(view, dst_data->data(), copy_shape,
/*dst_elem_offset=*/n_elem);
return view;
}
// The commit of the plain auxiliary data manager is no-op.
void CommitCompactKVAuxDataCopy() final {}
private:
/*!
* \brief Copy a vector of data to the input Tensor.
* It optionally supports specifying the shape of copy and the element
* offset to the destination Tensor.
*/
void CopyVecDataToArray(Tensor array, int32_t* vec_data,
ffi::Optional<ffi::Shape> shape = std::nullopt, int dst_elem_offset = 0) {
if (array->shape[0] == 0) {
return;
}
DLTensor copy_dst = *array.operator->();
#if defined(OPENCL_ENABLE_HOST_PTR)
tvm::runtime::cl::OpenCLWorkspace* workspace = tvm::runtime::cl::OpenCLWorkspace::Global();
if (workspace->IsOpenCLDevice(copy_dst.device)) {
void* nptr = workspace->GetNativePtr(array);
uint64_t copy_size;
if (shape.defined()) {
TVM_FFI_ICHECK_EQ(shape.value().size(), 1);
copy_size = shape.value()->data[0] * sizeof(int32_t);
} else {
copy_size = DeviceAPI::Get(array->device)->GetDataSize(*array.operator->());
}
memcpy(static_cast<char*>(nptr) + dst_elem_offset * sizeof(int32_t), vec_data, copy_size);
return;
}
#endif
if (shape.defined()) {
TVM_FFI_ICHECK_EQ(shape.value().size(), 1);
copy_dst.ndim = 1;
copy_dst.shape = const_cast<int64_t*>(shape.value()->data);
}
copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t);
DLTensor copy_src;
copy_src.data = vec_data;
copy_src.device = preferred_host_device_;
copy_src.ndim = 1;
copy_src.dtype = array->dtype;
copy_src.shape = copy_dst.shape;
copy_src.strides = nullptr;
copy_src.byte_offset = 0;
Tensor::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
}
std::vector<Tensor> qo_indptr_on_depths_device_;
std::vector<Tensor> page_indptr_on_depths_device_;
std::vector<Tensor> page_indices_on_depths_device_;
std::vector<Tensor> length_info_on_depths_device_;
std::vector<Tensor> k_rope_pos_offset_on_depths_device_;
std::vector<Tensor> tree_attn_mask_device_;
std::vector<Tensor> tree_attn_mn_indptr_device_;
Tensor cur_append_length_indptr_device_;
Tensor k_ragged_rope_pos_offset_device_;
Tensor q_rope_position_map_device_;
Tensor append_position_map_device_;
Tensor kv_transfer_remote_position_map_device;
Tensor kv_transfer_recver_id_device;
Tensor kv_transfer_page_to_page_local_position_map_device;
Tensor kv_transfer_page_to_page_remote_position_map_device;
Tensor kv_transfer_page_to_page_recver_id_device;
Tensor commit_copy_length_indptr_device_;
Tensor commit_copy_src_dst_pos_in_page_table_device_;
};
/*!
* \brief The cached auxiliary data manager class.
* It allocates a large on-device array to store all the auxiliary data.
* For each `CopyXXXAsync`, it copies the input data to a local cache on host.
* In `CommitAttnAuxDataCopy`, it copies all the data in the local cache to the device
* array for a single time, and thus reduce the number of host-to-device copies needed.
*/
class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
public:
explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages,
int64_t prefill_chunk_size, DLDataType dtype_aux,
Device device, Device preferred_host_device,
TVMStreamHandle copy_stream)
: PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream),
elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8),
offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) {
// - Calculate cache size of all the attention auxiliary arrays in
// local cache and the large on-device array.
int64_t attn_aux_data_cache_size =
CalculateAttnAuxDataCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size);
// - Initialize the host auxiliary data buffer.
merged_attn_aux_data_host_ =
HostMemoryVector(attn_aux_data_cache_size, dtype_aux, preferred_host_device);
// - Initialize the device auxiliary data buffer.
merged_attn_aux_data_device_ = Tensor::Empty({attn_aux_data_cache_size}, dtype_aux, device);
// - Calculate cache size of all the compact KV auxiliary arrays in
// local cache and the large on-device array.
int64_t compact_kv_aux_data_cache_size =
CalculateCompactKVAuxDataCacheSize(reserved_num_seqs, prefill_chunk_size);
// - Initialize the host auxiliary data buffer.
merged_compact_kv_aux_data_host_ =
HostMemoryVector(compact_kv_aux_data_cache_size, dtype_aux, preferred_host_device);
merged_compact_kv_aux_data_device_ =
Tensor::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device);
}
void ResetAttnAuxDataCopy() final { attn_aux_data_copy_offset_ = 0; }
Tensor CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); }
Tensor CopyAppendPositionMapAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final {
Tensor mask_1d = CopyAttnAuxVecToCache(data);
return mask_1d.CreateView({static_cast<int64_t>(data->size() / 2), 2}, mask_1d->dtype);
}
Tensor CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
return CopyAttnAuxVecToCache(data);
}
Tensor CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
HostMemoryVector* sliding_window_offset,
HostMemoryVector* sink_size, int depth) final {
int64_t n_elem = last_page_len->size();
std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_,
last_page_len->data(), n_elem * elem_byte_size_);
std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + n_elem,
sliding_window_offset->data(), n_elem * elem_byte_size_);
std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem,
sink_size->data(), n_elem * elem_byte_size_);
Tensor view =
Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), ffi::Shape({3, n_elem}),
dtype_aux_, device_, attn_aux_data_copy_offset_ * elem_byte_size_);
attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem);
return view;
}
void CommitAttnAuxDataCopy() final {
std::vector<int64_t> copy_shape{attn_aux_data_copy_offset_};
DLTensor copy_dst;
copy_dst.data = merged_attn_aux_data_device_->data;
copy_dst.device = device_;
copy_dst.ndim = 1;
copy_dst.dtype = dtype_aux_;
copy_dst.shape = copy_shape.data();
copy_dst.strides = nullptr;
copy_dst.byte_offset = 0;
DLTensor copy_src = copy_dst;
copy_src.data = merged_attn_aux_data_host_.data();
copy_src.device = Device{kDLCPU, 0};
Tensor::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
}
void ResetCompactKVAuxDataCopy() final { compact_kv_aux_data_copy_offset_ = 0; }
Tensor CopyCommitLengthIndptrAsync(HostMemoryVector* data) final {
return CopyCompactKVAuxVecToCache(data);
}
Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
HostMemoryVector* dst_data) final {
int64_t n_elem = src_data->size();
std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_,
src_data->data(), n_elem * elem_byte_size_);
std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem,
dst_data->data(), n_elem * elem_byte_size_);
Tensor view = Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_),
ffi::Shape({2, n_elem}), dtype_aux_, device_,
compact_kv_aux_data_copy_offset_ * elem_byte_size_);
compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem);
return view;
}
void CommitCompactKVAuxDataCopy() final {
std::vector<int64_t> copy_shape{compact_kv_aux_data_copy_offset_};
DLTensor copy_dst;
copy_dst.data = merged_compact_kv_aux_data_device_->data;
copy_dst.device = device_;
copy_dst.ndim = 1;
copy_dst.dtype = dtype_aux_;
copy_dst.shape = copy_shape.data();
copy_dst.strides = nullptr;
copy_dst.byte_offset = 0;
DLTensor copy_src = copy_dst;
copy_src.data = merged_compact_kv_aux_data_host_.data();
copy_src.device = Device{kDLCPU, 0};
Tensor::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
}
private:
// helper allocator class that applies byte offset to the original data pointer
class ViewHelper {
public:
explicit ViewHelper(Tensor source) : source_(source) {}
void AllocData(DLTensor* tensor, int64_t extra_byte_offset) {
tensor->data = static_cast<char*>(source_->data) + extra_byte_offset;
}
void FreeData(DLTensor* tensor) {}
private:
Tensor source_;
};
/*!
* \brief Calculate the start element offsets of the auxiliary arrays in the local cache.
* \return Return the local cache size (total number of elements in the local cache).
*/
int64_t CalculateAttnAuxDataCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages,
int64_t prefill_chunk_size) {
int64_t cache_size = 0;
// - Array size of the arrays that every depth has.
// Corresponding to the following arrays respectively
// - qo_indptr_in_depth
// - page_indptr_in_depth
// - page_indices_in_depth
// - length_info_in_depth
// - k_rope_pos_offset_in_depth
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
cache_size += CeilDivElemAlignment(num_total_pages);
cache_size += CeilDivElemAlignment(3 * reserved_num_seqs);
cache_size += CeilDivElemAlignment(reserved_num_seqs);
cache_size *= kPagedKVCacheMaxBlockDepth;
// - Array size of other arrays.
// Corresponding to the following arrays respectively
// - cur_append_length_indptr
// - k_ragged_rope_pos_offset
// - q_rope_position_map
// - append_position_map
// - kv_transfer_remote_position_map
// - kv_transfer_recver_id
// - kv_transfer_page_to_page_local_position_map
// - kv_transfer_page_to_page_remote_position_map
// - kv_transfer_page_to_page_recver_id
// - tree_attn_mask
// - tree_attn_mn_indptr
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
cache_size += CeilDivElemAlignment(reserved_num_seqs);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size += CeilDivElemAlignment(prefill_chunk_size);
cache_size +=
CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs);
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
return cache_size;
}
int64_t CalculateCompactKVAuxDataCacheSize(int64_t reserved_num_seqs,
int64_t prefill_chunk_size) {
int64_t cache_size = 0;
// Corresponding to the following arrays respectively
// - commit_copy_length_indptr
// - commit_copy_src_dst_pos_in_page_table
cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
cache_size += CeilDivElemAlignment(
2 * std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size));
return cache_size;
}
/*!
* \brief Copy the input data to the cache at the given offset.
* And return the Tensor view of the cache starting at the offset.
*/
Tensor CopyAttnAuxVecToCache(HostMemoryVector* data) {
int64_t n_elem = data->size();
std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(),
n_elem * elem_byte_size_);
Tensor view =
Tensor::FromNDAlloc(ViewHelper(merged_attn_aux_data_device_), ffi::Shape({n_elem}),
dtype_aux_, device_, attn_aux_data_copy_offset_ * elem_byte_size_);
attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
return view;
}
Tensor CopyCompactKVAuxVecToCache(HostMemoryVector* data) {
int64_t n_elem = data->size();
std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_,
data->data(), n_elem * elem_byte_size_);
Tensor view = Tensor::FromNDAlloc(ViewHelper(merged_compact_kv_aux_data_device_),
ffi::Shape({n_elem}), dtype_aux_, device_,
compact_kv_aux_data_copy_offset_ * elem_byte_size_);
compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
return view;
}
/*! \brief For safety, we align the start offset of the arrays to `offset_alignment`. */
int64_t CeilDivElemAlignment(int n) {
return (n + offset_alignment_ - 1) / offset_alignment_ * offset_alignment_;
}
const int64_t cuda_byte_alignment_ = 16;
const int64_t elem_byte_size_;
const int64_t offset_alignment_;
int64_t attn_aux_data_copy_offset_ = 0;
int64_t compact_kv_aux_data_copy_offset_ = 0;
HostMemoryVector merged_attn_aux_data_host_;
HostMemoryVector merged_compact_kv_aux_data_host_;
Tensor merged_attn_aux_data_device_;
Tensor merged_compact_kv_aux_data_device_;
};
} // namespace vm
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_VM_ATTN_UTILS_H_