# blob: 8dcd7bf6128951f74d7a3f4f616ce5c146f19517
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import random
import numpy as np
import tvm
import tvm.testing
import pytest
from tvm import relax
from tvm.contrib import utils
from typing import List
@pytest.mark.skip(reason="Requires FlashInfer enabled and proper setup")
def test_sampling():
    """Check FlashInfer's ``sampling_from_probs`` kernel against its input distribution.

    Every row of the batch holds the same categorical distribution over
    ``vocab_size`` tokens. The kernel is launched ``num_iterations`` times with a
    fresh Philox seed/offset each time. The sampled token ids are tallied per row,
    and the empirical frequencies are compared with the input probabilities.

    Needs a CUDA device and a TVM build with FlashInfer support, so it is skipped
    under pytest by default.
    """

    def load_module(name: str, static_modules: List[tvm.runtime.Module]) -> tvm.runtime.Module:
        """Combine ``static_modules`` into a single runtime module.

        A single module is returned unchanged. Otherwise the first module imports
        the remaining ones. The result is exported as ``<name>.so`` into a
        temporary directory and loaded back with ``tvm.runtime.load_module``.
        """
        assert len(static_modules) > 0, "load_module needs at least one module"
        if len(static_modules) == 1:
            return static_modules[0]
        static_mod = static_modules[0]
        for mod in static_modules[1:]:
            static_mod.import_module(mod)
        temp = utils.tempdir()
        mod_path = temp.relpath(f"{name}.so")
        static_mod.export_library(mod_path)
        return tvm.runtime.load_module(mod_path)

    # Test configuration
    batch_size = 10
    vocab_size = 5
    # Each frequency estimate has std sqrt(p * (1 - p) / num_iterations).
    # With 1000 iterations and p = 0.3 that is ~0.0145 against an allowed error
    # of atol + rtol * p = 0.035 (~2.4 sigma). Across 50 (row, token) checks a
    # correct kernel would then fail often (assuming independent draws).
    # 10000 iterations puts every check beyond ~7 sigma.
    num_iterations = 10000
    tol_atol = 0.02
    tol_rtol = 0.05  # relative tolerance

    # Probability tensor: every row is the same distribution and sums to 1.
    probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32")
    assert probs_np.shape == (batch_size, vocab_size)
    tvm.testing.assert_allclose(probs_np.sum(axis=1), np.ones(batch_size), atol=1e-6)

    # A single device handle serves both the tensors and the build target.
    dev = tvm.cuda(0)
    target = tvm.target.Target.from_device(dev)
    prob_tvm = tvm.runtime.tensor(probs_np, device=dev)
    output_tvm = tvm.runtime.empty((batch_size,), "int32", device=dev)

    sampling_mod = load_module(
        "flashinfer_sampling",
        relax.backend.cuda.flashinfer.gen_sampling_module(target=target),
    )
    sampling_func = sampling_mod["sampling_from_probs"]

    # The `deterministic` flag passed to the kernel is False for every launch.
    deterministic = False
    row_ids = np.arange(batch_size)
    counts = np.zeros((batch_size, vocab_size), dtype="int32")
    for _ in range(num_iterations):
        # Generate a seed and a random offset for this launch.
        philox_seed = np.uint64(random.getrandbits(63))
        philox_offset = np.uint64(random.getrandbits(63) % 1000)
        # The kernel expects
        # (probs, output, maybe_indices, deterministic, philox_seed, philox_offset, cuda_stream).
        sampling_func(prob_tvm, output_tvm, None, deterministic, philox_seed, philox_offset, 0)
        out = output_tvm.numpy()
        # Reject bad ids explicitly. Otherwise NumPy's negative indexing would
        # silently count them in the wrong bucket.
        assert np.all((out >= 0) & (out < vocab_size)), f"sampled token ids out of range: {out}"
        # Each row index appears exactly once, so fancy-indexed += does not drop updates.
        counts[row_ids, out] += 1

    # Convert counts to frequencies and compare every row with its input probabilities.
    frequencies = counts / float(num_iterations)
    tvm.testing.assert_allclose(frequencies, probs_np, rtol=tol_rtol, atol=tol_atol)
if __name__ == "__main__":
    # Run the test standalone (when not using pytest). Calling the function
    # directly bypasses its ``pytest.mark.skip`` marker, so this path really
    # executes the sampling check and needs a CUDA device with FlashInfer.
    test_sampling()