[KVCache] Support KVCache decode from forked sequence and pop more tokens (#16995)

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index b07ae3d..a5d2d9f 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -925,10 +925,21 @@
     if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
       // Fork at last by appending a new block directly
       int32_t parent_block_idx = parent_it->second.last_block_idx;
+      if (!global_block_pool_[parent_block_idx].seq_length) {
+        // If parent ends with empty block, fork from parent's parent block
+        parent_block_idx = global_block_pool_[parent_block_idx].parent_idx;
+      }
       ++global_block_pool_[parent_block_idx].external_ref_cnt;
       // Update child block start position and parent index
       global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
       global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
+      if (global_block_pool_[parent_block_idx].seq_length) {
+        // If parent is not empty, append a new block
+        int32_t new_parent_block_idx = GetFreeBlock();
+        global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
+        global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
+        parent_it->second.last_block_idx = new_parent_block_idx;
+      }
     } else {
       // Locate the block to fork from and calculate in-block offset
       std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
@@ -1038,21 +1049,51 @@
     auto it = seq_map_.find(seq_id);
     CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
 
-    Block& block = global_block_pool_[it->second.last_block_idx];
     CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative.";
-    CHECK_LE(n, block.seq_length) << "The sequence only has length " << block.seq_length
-                                  << " in the last block, while the length of pop is " << n
-                                  << " which exceeds the last-block sequence length.";
-
-    int64_t cur_npage = block.page_ids.size();
-    int64_t tgt_npage = (block.seq_length - n + page_size_ - 1) / page_size_;
-    while (cur_npage > tgt_npage) {
-      free_page_ids_.push_back(block.page_ids.back());
-      block.page_ids.pop_back();
-      --cur_npage;
+    CHECK_LE(n, it->second.seq_length)
+        << "The sequence only has length " << it->second.seq_length
+        << ", while the length of pop is " << n << " which exceeds the whole sequence length.";
+    int32_t block_idx = it->second.last_block_idx;
+    while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+      if (n > global_block_pool_[block_idx].seq_length) {
+        n -= global_block_pool_[block_idx].seq_length;
+        it->second.seq_length -= global_block_pool_[block_idx].seq_length;
+        for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
+          free_page_ids_.push_back(page_id);
+        }
+        free_block_idx_.push_back(block_idx);
+        block_idx = global_block_pool_[block_idx].parent_idx;
+        it->second.last_block_idx = block_idx;
+        continue;
+      }
+      if (n <= global_block_pool_[block_idx].seq_length) {
+        int64_t cur_npage = global_block_pool_[block_idx].page_ids.size();
+        int64_t tgt_npage =
+            (global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_;
+        while (cur_npage > tgt_npage) {
+          free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back());
+          global_block_pool_[block_idx].page_ids.pop_back();
+          --cur_npage;
+        }
+        it->second.seq_length -= n;
+        global_block_pool_[block_idx].seq_length -= n;
+        n = 0;
+        break;
+      }
     }
-    it->second.seq_length -= n;
-    block.seq_length -= n;
+
+    if (n) {
+      int32_t temp_seq_id = -1 - seq_id;
+      CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
+      ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
+      CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
+      RemoveSequence(seq_id);
+      CHECK(seq_map_.find(seq_id) == seq_map_.end());
+      auto it = seq_map_.find(temp_seq_id);
+      seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
+      seq_map_.erase(temp_seq_id);
+    }
+
     dirty_aux_data_device_ = true;
   }