[KVCache] Support KVCache decode from forked sequence and pop more tokens (#16995)
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index b07ae3d..a5d2d9f 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -925,10 +925,21 @@
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
// Fork at last by appending a new block directly
int32_t parent_block_idx = parent_it->second.last_block_idx;
+ if (!global_block_pool_[parent_block_idx].seq_length) {
+ // If the parent sequence's last block is empty, fork from that block's parent block instead
+ parent_block_idx = global_block_pool_[parent_block_idx].parent_idx;
+ }
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+ if (global_block_pool_[parent_block_idx].seq_length) {
+ // If the shared block is non-empty, give the parent sequence a new last block chained after it
+ int32_t new_parent_block_idx = GetFreeBlock();
+ global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
+ global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
+ parent_it->second.last_block_idx = new_parent_block_idx;
+ }
} else {
// Locate the block to fork from and calculate in-block offset
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
@@ -1038,21 +1049,51 @@
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
- Block& block = global_block_pool_[it->second.last_block_idx];
CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative.";
- CHECK_LE(n, block.seq_length) << "The sequence only has length " << block.seq_length
- << " in the last block, while the length of pop is " << n
- << " which exceeds the last-block sequence length.";
-
- int64_t cur_npage = block.page_ids.size();
- int64_t tgt_npage = (block.seq_length - n + page_size_ - 1) / page_size_;
- while (cur_npage > tgt_npage) {
- free_page_ids_.push_back(block.page_ids.back());
- block.page_ids.pop_back();
- --cur_npage;
+ CHECK_LE(n, it->second.seq_length)
+ << "The sequence only has length " << it->second.seq_length
+ << ", while the length of pop is " << n << " which exceeds the whole sequence length.";
+ int32_t block_idx = it->second.last_block_idx;
+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+ if (n > global_block_pool_[block_idx].seq_length) {
+ n -= global_block_pool_[block_idx].seq_length;
+ it->second.seq_length -= global_block_pool_[block_idx].seq_length;
+ for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
+ free_page_ids_.push_back(page_id);
+ }
+ free_block_idx_.push_back(block_idx);
+ block_idx = global_block_pool_[block_idx].parent_idx;
+ it->second.last_block_idx = block_idx;
+ continue;
+ }
+ if (n <= global_block_pool_[block_idx].seq_length) {
+ int64_t cur_npage = global_block_pool_[block_idx].page_ids.size();
+ int64_t tgt_npage =
+ (global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_;
+ while (cur_npage > tgt_npage) {
+ free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back());
+ global_block_pool_[block_idx].page_ids.pop_back();
+ --cur_npage;
+ }
+ it->second.seq_length -= n;
+ global_block_pool_[block_idx].seq_length -= n;
+ n = 0;
+ break;
+ }
}
- it->second.seq_length -= n;
- block.seq_length -= n;
+
+ if (n) {
+ int32_t temp_seq_id = -1 - seq_id;
+ CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
+ ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
+ CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
+ RemoveSequence(seq_id);
+ CHECK(seq_map_.find(seq_id) == seq_map_.end());
+ auto it = seq_map_.find(temp_seq_id);
+ seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
+ seq_map_.erase(temp_seq_id);
+ }
+
dirty_aux_data_device_ = true;
}