# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
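# Test KV cache transfer between two MPI ranks (disaggregated prefill/decode) over NVSHMEM.
# See the __main__ block below for how to launch this file with mpirun.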
import itertools
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pytest
import torch
import tvm
import tvm.testing
from tvm.s_tir import dlight as dl
from tvm.relax.frontend.nn.llm.kv_cache import (
AttnKind,
RopeMode,
_attention_decode,
_attention_prefill,
_attention_prefill_ragged,
_compact_kv_copy,
_copy_single_page,
_kv_cache_debug_get_kv,
_kv_cache_transpose_append,
_merge_state_inplace,
llama_rope_with_position_map,
tree_attn,
tree_attn_with_paged_kv_cache,
)
from tvm.runtime import ShapeTuple
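# Return the MPI communicator and rank when mpi4py is available; fall back to (None, 0) otherwise.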
def get_comm_rank():
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
return comm, rank
except ImportError:
return None, 0
comm, rank = get_comm_rank()
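# Test-wide KV cache configuration; the head_dim/dtype-dependent globals are filled in per test.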
reserved_nseq = 32
maximum_total_seq_length = 2048
prefill_chunk_size = 512
page_size = 16
num_layers = 4
num_qo_heads = 32
num_kv_heads = 4
head_dim = None
sm_scale = None
rope_scale = 1.0
rope_theta = 1e4
rope_scaling = {}
dtype = None
dtype_torch = None
device = tvm.cuda(rank)
device_torch = torch.device(f"cuda:{rank}")
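# Handles to runtime builtins and compiled TIR kernels, populated by set_global_func().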
fclear = None
fadd_sequence = None
fremove_sequence = None
ffork_sequence = None
fenable_sliding_window_for_seq = None
fpopn = None
fbegin_forward = None
fend_forward = None
fcommit_accepted_token_tree_nodes = None
fattention_with_fuse_qkv = None
fis_empty = None
fdebug_get_kv = None
fnvshmem_get_uid = None
fnvshmem_init = None
fdisagg_mark_send = None
fdisagg_prepare_recv = None
ftranspose_append = None
fcopy_cache = None
fattn_prefill = None
fattn_decode = None
fattn_prefill_sliding_window = None
fattn_decode_sliding_window = None
fattn_prefill_ragged = None
fattn_prefill_with_tree_mask = None
fattn_prefill_with_tree_mask_paged_kv_cache = None
fmerge_state = None
fsplit_rotary = None
fattention_rotary = None
fcopy_single_page = None
fcompact_copy = None
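# Resolve the KV cache runtime builtins and compile the TIR kernels for the given head_dim and dtype.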
def set_global_func(head_dim, dtype):
global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes
global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode
global fattn_prefill_ragged, fattn_prefill_with_tree_mask, fattn_prefill_with_tree_mask_paged_kv_cache
global fattn_prefill_sliding_window, fattn_decode_sliding_window
global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy
global fnvshmem_get_uid, fnvshmem_init, fdisagg_mark_send, fdisagg_prepare_recv
fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence")
ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence")
fenable_sliding_window_for_seq = tvm.get_global_func(
"vm.builtin.attention_kv_cache_enable_sliding_window_for_seq"
)
fpopn = tvm.get_global_func("vm.builtin.kv_state_popn")
fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
fcommit_accepted_token_tree_nodes = tvm.get_global_func(
"vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
)
fattention_with_fuse_qkv = tvm.get_global_func(
"vm.builtin.attention_kv_cache_attention_with_fused_qkv"
)
fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
fnvshmem_get_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
fnvshmem_init = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
fdisagg_mark_send = tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")
fdisagg_prepare_recv = tvm.get_global_func("vm.builtin.kv_cache_disagg_prepare_recv")
target = tvm.target.Target.from_device(device)
builts = []
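    # Compile each TIR kernel with the default dlight GPU schedule; the list order must match
    # the unpacking into the f* globals at the end of this function.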
for tir_func in [
_kv_cache_transpose_append(num_kv_heads, head_dim, dtype),
_kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype),
_attention_prefill(
num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling, target
),
_attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling, target),
_attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target),
_attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target),
_attention_prefill_ragged(
num_kv_heads, num_qo_heads, head_dim, head_dim, dtype, rope_scaling, target
),
tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target),
tree_attn_with_paged_kv_cache(
num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target
),
_merge_state_inplace(num_qo_heads, head_dim, dtype, target),
llama_rope_with_position_map(
rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, rope_scaling
),
_copy_single_page(num_kv_heads, page_size, head_dim, dtype, target),
_compact_kv_copy(num_kv_heads, head_dim, dtype, target),
]:
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
f = tvm.tir.build(mod["main"], target=target)
builts.append(f.main)
(
ftranspose_append,
fcopy_cache,
fattn_prefill,
fattn_decode,
fattn_prefill_sliding_window,
fattn_decode_sliding_window,
fattn_prefill_ragged,
fattn_prefill_with_tree_mask,
fattn_prefill_with_tree_mask_paged_kv_cache,
fmerge_state,
fsplit_rotary,
fcopy_single_page,
fcompact_copy,
) = builts
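# Create a paged attention KV cache wired to the runtime builtins and TIR kernels prepared above.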
def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create")
cache = fcreate(
tvm.runtime.ShapeTuple(
[
reserved_nseq,
maximum_total_seq_length,
prefill_chunk_size,
page_size,
int(support_sliding_window),
]
),
tvm.runtime.ShapeTuple([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim, # v_head_dim
tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]),
False, # enable_kv_transfer
rope_mode,
rope_scale,
rope_theta,
None, # rope_ext_factors
tvm.runtime.empty((), dtype, device=device),
ftranspose_append,
None, # f_transpose_append_mla
["tir", fattn_prefill_ragged],
["tir", fattn_prefill],
["tir", fattn_decode],
["tir", fattn_prefill_sliding_window],
["tir", fattn_decode_sliding_window],
["tir", fattn_prefill_with_tree_mask_paged_kv_cache],
["tir", fattn_prefill_with_tree_mask],
[], # f_mla_prefill
[fmerge_state],
fsplit_rotary,
fcopy_single_page,
fcopy_cache,
fcompact_copy,
)
return cache
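# Each fixture parameter tuple is (head_dim, dtype, rope_mode, support_sliding_window).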
@pytest.fixture(
params=itertools.chain(
itertools.product(
[64, 128],
["float32", "float16"],
[RopeMode.NORMAL],
[False],
),
itertools.product(
[128],
["float16"],
[RopeMode.NONE, RopeMode.INLINE],
[False, True],
),
)
)
def kv_cache_and_config(request):
    global head_dim, sm_scale, dtype, dtype_torch
    head_dim, dtype, rope_mode, support_sliding_window = request.param
    # Mirror the string dtype as a torch dtype for the reference computation.
    dtype_torch = getattr(torch, dtype)
    sm_scale = head_dim ** (-0.5)
set_global_func(head_dim, dtype)
return create_kv_cache(*request.param), rope_mode, support_sliding_window
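# Read K/V back from the cache with the debug builtin and compare against the expected torch tensors.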
def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
for seq_id in seq_ids:
keys_expected = expected_k[seq_id]
values_expected = expected_v[seq_id]
assert keys_expected.shape == values_expected.shape
seq_length = expected_k[seq_id].shape[1]
keys = tvm.runtime.empty(keys_expected.shape, dtype=dtype, device=device)
values = tvm.runtime.empty(values_expected.shape, dtype=dtype, device=device)
fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values)
torch.testing.assert_close(
torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3
)
torch.testing.assert_close(
torch.from_numpy(values.numpy()).to(device_torch), values_expected, rtol=1e-3, atol=1e-3
)
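# Torch reference for rotary position embedding: x * cos(pos * inv_freq) + rotate_half(x) * sin(pos * inv_freq).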
def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None):
# x: (N, H, D)
assert len(x.shape) == 3
nfeat = x.shape[-1]
nfeat_half = x.shape[-1] // 2
x_dtype = x.dtype
x = x.to(torch.float32)
y = torch.cat([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], dim=-1)
inv_freq = scale / (
theta ** (torch.arange(0, nfeat, 2, device=device_torch, dtype=torch.float32) / nfeat)
)
t = (
torch.arange(offset, offset + x.shape[0], device=device_torch, dtype=inv_freq.dtype)
if offset_list is None
else (torch.tensor(offset_list, dtype=inv_freq.dtype, device=device_torch) + offset)
)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_values = torch.cos(emb)
sin_values = torch.sin(emb)
return torch.einsum("ij,ikj->ikj", cos_values, x).to(x_dtype) + torch.einsum(
"ij,ikj->ikj", sin_values, y
).to(x_dtype)
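# Apply one batch of (seq_id, append_length) steps to the KV cache with fused-QKV attention,
# verify the outputs against a dense torch reference, and keep host-side copies of the expected K/V.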
def apply_attention(
kv_cache,
rope_mode: RopeMode,
batch: List[Tuple[Union[int, Tuple[int, int, int]], int]],
cached_k: Dict[int, torch.Tensor],
cached_v: Dict[int, torch.Tensor],
sliding_window_sizes: Optional[List[int]] = None,
attn_sink_sizes: Optional[List[int]] = None,
token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
accepted_leaf_indices: Optional[List[int]] = None,
only_update_host=False,
skip_add_sequence=False,
) -> None:
seq_ids = []
append_lengths = []
for i, (seq_id, append_length) in enumerate(batch):
fork_parent_id = None
if isinstance(seq_id, tuple):
# Fork sequence
seq_id, fork_parent_id, fork_pos = seq_id
batch[i] = (seq_id, append_length)
seq_ids.append(seq_id)
append_lengths.append(append_length)
if fork_parent_id is not None:
assert fork_parent_id in cached_k
assert seq_id not in cached_k
if not only_update_host:
ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos)
if fork_pos == -1:
cached_k[seq_id] = cached_k[fork_parent_id]
cached_v[seq_id] = cached_v[fork_parent_id]
else:
                cached_k[seq_id] = cached_k[fork_parent_id][:, :fork_pos]
                cached_v[seq_id] = cached_v[fork_parent_id][:, :fork_pos]
elif seq_id not in cached_k:
if not only_update_host and not skip_add_sequence:
fadd_sequence(kv_cache, seq_id)
cached_k[seq_id] = torch.zeros(
(num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch
)
cached_v[seq_id] = torch.zeros(
(num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch
)
flattened_token_tree_parent_ptr = None
token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch]
if token_tree_parent_ptr_list:
assert len(token_tree_node_depths_list) == len(seq_ids)
if accepted_leaf_indices is not None:
assert len(accepted_leaf_indices) == len(seq_ids)
flattened_token_tree_parent_ptr = []
for i, (token_tree_parent_ptr, append_length) in enumerate(
zip(token_tree_parent_ptr_list, append_lengths)
):
assert len(token_tree_parent_ptr) >= append_length
# parent pointer for the last `append_length` nodes (the new tokens)
append_token_tree_parent_ptr = token_tree_parent_ptr[-append_length:]
flattened_token_tree_parent_ptr += append_token_tree_parent_ptr
token_tree_node_depths = []
for parent in token_tree_parent_ptr:
token_tree_node_depths.append(
0 if parent == -1 else token_tree_node_depths[parent] + 1
)
# depth of each node in the tree (this contains more than the last `append_length` nodes)
token_tree_node_depths_list[i] = token_tree_node_depths
if not only_update_host:
fbegin_forward(
kv_cache,
ShapeTuple(seq_ids),
ShapeTuple(append_lengths),
(
ShapeTuple(flattened_token_tree_parent_ptr)
if flattened_token_tree_parent_ptr is not None
else None
),
)
global_new_q = torch.zeros(
(num_layers, 0, num_qo_heads, head_dim), dtype=dtype_torch, device=device_torch
)
global_new_k = torch.zeros(
(num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch
)
global_new_v = torch.zeros(
(num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch
)
q_array = []
for i, (seq_id, append_length) in enumerate(batch):
new_q = torch.rand(
num_layers,
append_length,
num_qo_heads,
head_dim,
dtype=dtype_torch,
device=device_torch,
)
new_k = torch.rand(
num_layers,
append_length,
num_kv_heads,
head_dim,
dtype=dtype_torch,
device=device_torch,
)
new_v = torch.rand(
num_layers,
append_length,
num_kv_heads,
head_dim,
dtype=dtype_torch,
device=device_torch,
)
new_q = new_q * 2 - 1
new_k = new_k * 2 - 1
new_v = new_v * 2 - 1
q_array.append(new_q)
rope_offset = cached_k[seq_id].shape[1]
if token_tree_parent_ptr_list is not None:
prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length
assert prev_tree_size >= 0
rope_offset -= prev_tree_size
cached_k[seq_id] = torch.cat(
[
cached_k[seq_id],
torch.stack(
[
(
new_k[l]
if rope_mode != RopeMode.NORMAL
else f_apply_rotary(
new_k[l],
rope_offset,
rope_scale,
rope_theta,
(
token_tree_node_depths_list[i][-append_length:]
if token_tree_node_depths_list[i] is not None
else None
),
)
)
for l in range(num_layers)
],
dim=0,
),
],
dim=1,
)
cached_v[seq_id] = torch.cat([cached_v[seq_id], new_v], dim=1)
global_new_q = torch.cat([global_new_q, new_q], dim=1)
global_new_k = torch.cat([global_new_k, new_k], dim=1)
global_new_v = torch.cat([global_new_v, new_v], dim=1)
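    # Run fused-QKV attention layer by layer and compare against a masked softmax reference.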
for layer_id in range(num_layers):
queries_np = global_new_q[layer_id]
keys_np = global_new_k[layer_id]
values_np = global_new_v[layer_id]
qkv = tvm.runtime.tensor(
torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device
)
outputs = tvm.runtime.empty(queries_np.shape, dtype, device=device)
if not only_update_host:
fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs)
# Compute attention expected results.
outputs = torch.from_numpy(outputs.numpy()).unsqueeze(0).to(device_torch)
sum_length = 0
for i, (seq_id, append_length) in enumerate(batch):
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length
rope_offset = cached_k[seq_id].shape[1]
if token_tree_parent_ptr_list is not None:
rope_offset -= len(token_tree_parent_ptr_list[i])
else:
rope_offset -= append_length
q_seq = (
q_array[i][layer_id]
if rope_mode == RopeMode.NONE
else f_apply_rotary(
q_array[i][layer_id],
rope_offset,
rope_scale,
rope_theta,
(
token_tree_node_depths_list[i][-append_length:]
if token_tree_node_depths_list[i] is not None
else None
),
)
).permute(1, 0, 2)
k_seq = (
cached_k[seq_id][layer_id]
if rope_mode != RopeMode.INLINE
else f_apply_rotary(
cached_k[seq_id][layer_id],
0,
rope_scale,
rope_theta,
(
(
list(range(rope_offset))
+ [depth + rope_offset for depth in token_tree_node_depths_list[i]]
)
if token_tree_node_depths_list[i] is not None
else None
),
)
).permute(1, 2, 0)
v_seq = cached_v[seq_id][layer_id].permute(1, 0, 2)
k_seq = k_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0)
v_seq = v_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0)
softmax_input = (q_seq.to(torch.float32) @ k_seq.to(torch.float32)) / (head_dim**0.5)
softmax_shape = softmax_input.shape
assert softmax_shape[-2] == append_length
length_diff = softmax_shape[-1] - softmax_shape[-2]
assert length_diff >= 0
mask = torch.tril(
torch.full_like(softmax_input, torch.finfo(torch.float32).max), diagonal=length_diff
) + torch.triu(
torch.full_like(softmax_input, torch.finfo(torch.float32).min),
diagonal=length_diff + 1,
)
if token_tree_parent_ptr_list is not None:
tree_size = len(token_tree_parent_ptr_list[i])
tree_mask = torch.full(
(tree_size, tree_size),
torch.finfo(torch.float32).min,
dtype=torch.float32,
device=device_torch,
)
                for node, parent in enumerate(token_tree_parent_ptr_list[i]):
                    if parent != -1:
                        tree_mask[node] = tree_mask[parent]
                    tree_mask[node, node] = torch.finfo(torch.float32).max
tree_mask = tree_mask.expand(num_qo_heads, *tree_mask.shape)
mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :]
softmax_input = torch.minimum(softmax_input, mask)
results = torch.unsqueeze(
(
torch.nn.functional.softmax(softmax_input, dim=-1) @ v_seq.to(torch.float32)
).permute(1, 0, 2),
dim=0,
).to(dtype_torch)
if not only_update_host:
torch.testing.assert_close(
outputs[:, sum_length : sum_length + append_length, ...],
results,
rtol=1e-3,
atol=1e-3,
)
sum_length += append_length
if not only_update_host:
fend_forward(kv_cache)
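    # Commit the accepted path of each token tree in the cache and mirror the compaction on the host copies.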
if accepted_leaf_indices is not None:
seq_ids = [seq_id for seq_id, _ in batch]
if not only_update_host:
fcommit_accepted_token_tree_nodes(
kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
)
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
tree_path = []
node = accepted_leaf_idx
while node != -1:
tree_path.append(node)
node = token_tree_parent_ptr_list[i][node]
offset = cached_k[seq_id].shape[1] - append_length
length_to_pop = append_length - len(tree_path)
assert 0 <= length_to_pop <= append_length
for dst_pos, src_pos in enumerate(reversed(tree_path)):
if dst_pos == src_pos:
continue
cached_k[seq_id][:, offset + dst_pos, ...] = cached_k[seq_id][
:, offset + src_pos, ...
]
cached_v[seq_id][:, offset + dst_pos, ...] = cached_v[seq_id][
:, offset + src_pos, ...
]
if length_to_pop > 0:
cached_k[seq_id] = cached_k[seq_id][:, :-length_to_pop, ...]
cached_v[seq_id] = cached_v[seq_id][:, :-length_to_pop, ...]
for seq_id, _ in batch:
if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) > seq_id
sliding_window_size = sliding_window_sizes[seq_id]
attn_sink_size = attn_sink_sizes[seq_id]
if sliding_window_size == 0:
continue
if cached_k[seq_id].shape[1] > sliding_window_size:
# Apply sliding window and sink to cached kv.
length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size
cached_k[seq_id] = torch.cat(
[
cached_k[seq_id][:, :attn_sink_size, ...],
cached_k[seq_id][:, attn_sink_size + length_to_slide :, ...],
],
dim=1,
)
cached_v[seq_id] = torch.cat(
[
cached_v[seq_id][:, :attn_sink_size, ...],
cached_v[seq_id][:, attn_sink_size + length_to_slide :, ...],
],
dim=1,
)
assert cached_k[seq_id].shape[1] == sliding_window_size
# Verify
if not only_update_host:
verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v)
@pytest.mark.skip(reason="Requires NVSHMEM")
def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
if support_sliding_window and rope_mode == RopeMode.NORMAL:
# Normal RoPE mode under sliding window settings is not supported.
return
fclear(kv_cache)
# Prefill.
operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]]
operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]]
# Decode
operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
cached_k = {}
cached_v = {}
for batch in operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
@pytest.mark.skip(reason="Requires NVSHMEM")
def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
    if support_sliding_window:
        # KV transfer is not exercised together with sliding window.
        return
    np.random.seed(0)
    # Seed torch as well so both MPI ranks generate identical K/V tensors; the receiving rank
    # rebuilds the host reference independently and compares it against the transferred cache.
    torch.manual_seed(0)
fclear(kv_cache)
# Prefill.
prefill_operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]]
prefill_operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
prefill_operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]]
prefill_len = {i: 0 for i in range(9)}
for batch in prefill_operation_seq:
for seq_id, append_length in batch:
prefill_len[seq_id] += append_length
# Decode
decode_operation_seq = [
[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]
]
decode_operation_seq += [
[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]
]
decode_operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
decode_operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
cached_k = {}
cached_v = {}
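    # Rank 0 acts as the prefill side (KV sender); rank 1 acts as the decode side (KV receiver).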
if rank == 0:
for seq_id, _ in prefill_len.items():
fadd_sequence(kv_cache, seq_id)
remote_pos_maps = None
remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
comm.Barrier()
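        # Mark each sequence's KV, starting at position 0, for transfer to PE 1 at the receiver's page positions.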
for seq_id in prefill_len.keys():
fdisagg_mark_send(kv_cache, seq_id, 0, ShapeTuple(remote_pos_maps[seq_id]), 1)
for batch in prefill_operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, skip_add_sequence=True)
device.sync()
comm.Barrier()
else:
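        # Receiver: add the sequences locally, reserve pages for the incoming KV, and share the page positions with the sender.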
remote_pos_maps = []
        for seq_id, length in prefill_len.items():
            fadd_sequence(kv_cache, seq_id)
            compressed_pos_map = list(fdisagg_prepare_recv(kv_cache, seq_id, length))
remote_pos_maps.append(compressed_pos_map)
remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
comm.Barrier()
for batch in prefill_operation_seq:
apply_attention(
kv_cache,
rope_mode,
batch,
cached_k,
cached_v,
only_update_host=True,
skip_add_sequence=True,
)
comm.Barrier()
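    # Both ranks run the decode steps; on the receiver this validates attention over the KV transferred from rank 0.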
for batch in decode_operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, skip_add_sequence=True)
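# Broadcast the NVSHMEM unique id from rank 0 over MPI and initialize NVSHMEM on every rank.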
def init_nvshmem(num_workers, pe_offset):
if rank == 0:
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
else:
uid = None
uid = comm.bcast(uid, root=0)
init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_func(uid, num_workers, pe_offset)
if __name__ == "__main__":
# To run this test, install mpi4py first, and then run
# mpirun -np 2 python tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
HEAD_DIMS = [128]
DTYPES = ["float16"]
ROPE_MODES = [RopeMode.NONE]
SUPPORT_SLIDING_WINDOW = [False]
init_nvshmem(2, rank)
for head_dim, dtype, rope_mode, support_sliding_window in itertools.product(
HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
):
        # Derive the globals that the pytest fixture would otherwise set.
        dtype_torch = getattr(torch, dtype)
        sm_scale = head_dim ** (-0.5)
        set_global_func(head_dim, dtype)
cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
cache_and_config = (cache, rope_mode, support_sliding_window)
test_paged_attention_kv_cache_transfer(cache_and_config)