blob: da6fdacebdbd29d96f7964ac8321e9c59b4f415c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test for FlashInfer GroupedGemm TVM integration"""
import math
import numpy as np
import pytest
import torch
import tvm
import tvm.testing
from tvm import relax
DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
fp8_dtype = "float8_e4m3fn"
###########################################
################# Helpers #################
###########################################
def has_flashinfer():
"""Check if FlashInfer is available"""
try:
from tvm.relax.backend.cuda import ( # pylint: disable=import-outside-toplevel
flashinfer,
)
return True
except ImportError:
return False
def has_cutlass():
"""Check if CUTLASS is available for SM90+ operations"""
if not tvm.get_global_func("device_api.cuda", True):
return False
try:
import pynvml # pylint: disable=import-outside-toplevel
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
return major >= 9 # SM90+
except:
return False
def calc_diff(x: np.ndarray, y: np.ndarray):
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode):
from einops import rearrange, reduce, repeat
"""
Quantizes a 2D or 3D tensor to FP8.
Args:
x (torch.Tensor): The 2D or 3D input tensor.
scale_shape (tuple): The shape of the scale tensor.
tile_shape (tuple): The shape of the tiles.
scale_major_mode (str): The tiling order, "K" for row-major like,
or another value for column-major like.
Returns:
tuple: A tuple containing the quantized FP8 tensor and the
calculated float32 scales.
"""
# 1. Assertions and Initial Setup
ndim = x.ndim
assert ndim == len(scale_shape) == len(tile_shape)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32)
# 2. Tiling and Scale Calculation
if ndim == 2:
s0, s1 = scale_shape
t0, t1 = tile_shape
if scale_major_mode == "K":
# Tile x and find the max absolute value in each tile
x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4)
x_scale = abs_max / fp8_amax
x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
# Broadcast scales back to the original tensor shape
scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1)
else:
# Handle column-major tiling
x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4)
x_scale = abs_max / fp8_amax
x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
# Permute scale axes before repeating to match layout
scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0")
scales_repeated = repeat(scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1)
elif ndim == 3:
s0, s1, s2 = scale_shape
t0, t1, t2 = tile_shape
if scale_major_mode == "K":
# Tile x and find the max absolute value in each tile
x_tiled = rearrange(
x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2
)
abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4)
x_scale = abs_max / fp8_amax
x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
# Broadcast scales back to the original tensor shape
scales_repeated = repeat(
x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2
)
else:
# Handle layout where the last two axes are swapped
x_tiled = rearrange(
x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2
)
abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4)
x_scale = abs_max / fp8_amax
x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
# Permute scale axes before repeating to match layout
scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1")
scales_repeated = repeat(
scales_permuted,
"s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)",
t0=t0,
t1=t1,
t2=t2,
)
# 3. Final Quantization
# Divide the original tensor by the broadcasted scales
x_fp32 = x / (scales_repeated + 1e-8)
# Convert the result to the target FP8 format
x_fp8 = x_fp32.to(torch.float8_e4m3fn)
return x_fp8, x_scale
def dequantize_fp8(x, x_scale, scale_major_mode):
from einops import rearrange
"""
Quantizes a 2D or 3D tensor to FP8.
Args:
x (torch.Tensor): The 2D or 3D input tensor.
scale_shape (tuple): The shape of the scale tensor.
tile_shape (tuple): The shape of the tiles.
scale_major_mode (str): The tiling order, "K" for row-major like,
or another value for column-major like.
Returns:
tuple: A tuple containing the quantized FP8 tensor and the
calculated float32 scales.
"""
# 1. Assertions and Initial Setup
ndim = x.ndim
assert ndim == len(x_scale.shape)
# 2. Tiling and Scale Calculation
if ndim == 2:
if scale_major_mode == "K":
s0, s1 = x_scale.shape
else:
s1, s0 = x_scale.shape
x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
if scale_major_mode == "K":
x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1")
else:
x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1")
out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)")
elif ndim == 3:
if scale_major_mode == "K":
s0, s1, s2 = x_scale.shape
else:
s0, s2, s1 = x_scale.shape
x = rearrange(
x.to(torch.float32),
"(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2",
s0=s0,
s1=s1,
s2=s2,
)
if scale_major_mode == "K":
x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1")
else:
x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1")
out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)")
return out
###########################################
########### Refernce generation ###########
###########################################
def compute_reference_grouped_gemm(
a_fp32: torch.Tensor, # (total_m, k)
b_fp32: torch.Tensor, # (batch_size, n, k)
m_indptr: torch.Tensor,
dtype_out: str, # (total_m, n)
):
"""Compute reference result using PyTorch operations"""
"""Compute reference result using original FP32 tensors"""
total_m, k = a_fp32.shape
batch_size, n, k2 = b_fp32.shape
assert k == k2
# Perform grouped GEMM computation directly on original FP32 data
results = []
for i in range(batch_size):
start_m = m_indptr[i].item()
end_m = m_indptr[i + 1].item()
# Extract group's portion of A
a_group = a_fp32[start_m:end_m, :] # [m_sizes[i], k]
b_group = b_fp32[i]
# Multiply with shared B matrix
result_group = torch.mm(a_group, b_group.T) # [m_sizes[i], n]
results.append(result_group)
result_fp32 = torch.cat(results, dim=0)
# Convert to output dtype
if dtype_out == "bfloat16":
result = result_fp32.to(torch.bfloat16)
elif dtype_out == "float16":
result = result_fp32.to(torch.float16)
else:
result = result_fp32
return result
###########################################
########### Test data generation ##########
###########################################
def generate_test_data(
m_sizes: list,
batch_size: int,
n: int,
k: int,
dtype_a: str,
dtype_b: str,
dtype_out: str,
scale_granularity_m: int,
scale_granularity_n: int,
scale_granularity_k: int,
scale_major_mode: str,
device: tvm.runtime.Device,
):
"""Generate test data for grouped GEMM operations"""
assert batch_size == len(
m_sizes
), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})"
# print(f"Device object: {device}")
torch_device = torch.device(f"cuda:{device.index}")
cum_m = [0] + list(np.cumsum(m_sizes))
total_m = cum_m[-1]
# Generate input matrices A and B (where we assert of form fp8) random data in fp32 first, then convert
assert dtype_a == "float8_e4m3fn"
a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32)
assert dtype_b == "float8_e4m3fn"
b_fp32 = torch.randn(batch_size, n, k, device=torch_device, dtype=torch.float32) / math.sqrt(k)
if scale_major_mode == "K": # K mode:
scale_a_shape = (total_m // scale_granularity_m, k // scale_granularity_k)
scale_b_shape = (batch_size, n // scale_granularity_n, k // scale_granularity_k)
else: # MN mode
scale_a_shape = (k // scale_granularity_k, total_m // scale_granularity_m)
scale_b_shape = (batch_size, k // scale_granularity_k, n // scale_granularity_n)
tile_a_shape = (scale_granularity_m, scale_granularity_k)
tile_b_shape = (1, scale_granularity_n, scale_granularity_k)
# quantize A, B
a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, scale_major_mode)
b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode)
if dtype_a == "float8_e4m3fn":
a_tvm = tvm.runtime.tensor(
a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
)
else:
a_tvm = tvm.runtime.from_dlpack(a_quantized)
if dtype_b == "float8_e4m3fn":
b_tvm = tvm.runtime.tensor(
b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
)
else:
b_tvm = tvm.runtime.from_dlpack(b_quantized)
scale_a_tvm = tvm.runtime.from_dlpack(scale_a)
scale_b_tvm = tvm.runtime.from_dlpack(scale_b)
# Create m_indptr for grouped operation
m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32)
m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device)
return {
"a": a_tvm,
"b": b_tvm,
"torch_a": a_fp32,
"torch_b": b_fp32,
"scale_a": scale_a_tvm,
"scale_b": scale_b_tvm,
"m_indptr": m_indptr_tvm,
"m_sizes": m_sizes,
"n": n,
"k": k,
"total_m": total_m,
"torch_scale_a": scale_a,
"torch_scale_b": scale_b,
"torch_m_indptr": m_indptr,
}
###########################################
############### Test driver ###############
###########################################
@pytest.mark.skipif(not has_flashinfer(), reason="FlashInfer not available")
@pytest.mark.skipif(not has_cutlass(), reason="CUTLASS SM90+ not available")
@pytest.mark.parametrize(
"dtype_a,dtype_b,dtype_out",
[
("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
("float8_e4m3fn", "float8_e4m3fn", "float16"),
],
)
@pytest.mark.parametrize(
"scale_granularity_m,scale_granularity_n,scale_granularity_k",
[
(1, 128, 128), # Row-wise A, block-wise B
],
)
@pytest.mark.parametrize("scale_major_mode", ["K", "MN"])
@pytest.mark.parametrize("mma_sm", [1, 2])
@pytest.mark.parametrize(
"test_case",
[
{"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k": 1024},
{"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512},
{"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768},
{"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768},
],
)
def test_grouped_gemm_correctness(
dtype_a,
dtype_b,
dtype_out,
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_major_mode,
mma_sm,
test_case,
):
"""Test correctness of GroupedGemm operations"""
device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
# Generate the module
mod = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0]
# Load the module
grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"]
# Generate test data
test_data = generate_test_data(
batch_size=test_case["batch_size"],
m_sizes=test_case["m_sizes"],
n=test_case["n"],
k=test_case["k"],
dtype_a=dtype_a,
dtype_b=dtype_b,
dtype_out=dtype_out,
scale_granularity_m=scale_granularity_m,
scale_granularity_n=scale_granularity_n,
scale_granularity_k=scale_granularity_k,
scale_major_mode=scale_major_mode,
device=device,
)
# Prepare output buffer
output_shape = (test_data["total_m"], test_data["n"])
if dtype_out == "bfloat16":
output = tvm.runtime.empty(output_shape, dtype="bfloat16", device=device)
elif dtype_out == "float16":
output = tvm.runtime.empty(output_shape, dtype="float16", device=device)
else:
output = tvm.runtime.empty(output_shape, dtype="float32", device=device)
# Create workspace buffers (required by the interface)
int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device)
float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device)
grouped_gemm_fn(
int_workspace, # int_workspace_buffer
float_workspace, # float_workspace_buffer
test_data["a"], # A
test_data["b"], # B
test_data["scale_a"], # SFA
test_data["scale_b"], # SFB
output, # D
test_data["m_indptr"], # m_indptr
test_data["n"], # n (scalar)
test_data["k"], # k (scalar)
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_major_mode,
mma_sm,
)
# Compute reference result
reference = compute_reference_grouped_gemm(
test_data["torch_a"],
test_data["torch_b"],
test_data["torch_m_indptr"],
dtype_out,
)
# Convert TVM output to PyTorch for comparison
output_torch = torch.as_tensor(output, device=test_data["torch_a"].device)
output_torch
# Compare results with appropriate tolerance
if dtype_out == "bfloat16":
rtol, atol = 1e-2, 1e-2
elif dtype_out == "float16":
rtol, atol = 1e-3, 1e-3
else:
rtol, atol = 1e-4, 1e-4
# Check shapes match
assert (
output_torch.shape == reference.shape
), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}"
diff = calc_diff(output_torch.cpu().double().numpy(), reference.cpu().double().numpy())
assert diff < 1e-3, f"diff too large {diff}"
if __name__ == "__main__":
tvm.testing.main()