| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Test for FlashInfer GroupedGemm TVM integration""" |
| |
| import math |
| |
| import numpy as np |
| import pytest |
| import torch |
| |
| import tvm |
| import tvm.testing |
| from tvm import relax |
| |
| DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 |
| fp8_dtype = "float8_e4m3fn" |
| |
| |
| ########################################### |
| ################# Helpers ################# |
| ########################################### |
| def has_flashinfer(): |
| """Check if FlashInfer is available""" |
| try: |
| from tvm.relax.backend.cuda import ( # pylint: disable=import-outside-toplevel |
| flashinfer, |
| ) |
| |
| return True |
| except ImportError: |
| return False |
| |
| |
| def has_cutlass(): |
| """Check if CUTLASS is available for SM90+ operations""" |
| if not tvm.get_global_func("device_api.cuda", True): |
| return False |
| try: |
| import pynvml # pylint: disable=import-outside-toplevel |
| |
| pynvml.nvmlInit() |
| handle = pynvml.nvmlDeviceGetHandleByIndex(0) |
| major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) |
| return major >= 9 # SM90+ |
    except Exception:  # pylint: disable=broad-except
        return False
| |
| |
| def calc_diff(x: np.ndarray, y: np.ndarray): |
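    """Relative difference: 1 - 2 * sum(x * y) / sum(x * x + y * y); 0 when x and y match."""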
| denominator = (x * x + y * y).sum() |
| sim = 2 * (x * y).sum() / denominator |
| return 1 - sim |
| |
| |
def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode):
    """
    Quantizes a 2D or 3D tensor to FP8.

    Args:
        x (torch.Tensor): The 2D or 3D input tensor.
        scale_shape (tuple): The shape of the scale tensor.
        tile_shape (tuple): The shape of the tiles.
        scale_major_mode (str): The tiling order, "K" for row-major-like,
            or any other value for column-major-like.

    Returns:
        tuple: A tuple containing the quantized FP8 tensor and the
            calculated float32 scales.
    """
    from einops import rearrange, reduce, repeat  # pylint: disable=import-outside-toplevel

| # 1. Assertions and Initial Setup |
| ndim = x.ndim |
| assert ndim == len(scale_shape) == len(tile_shape) |
| |
| fp8_info = torch.finfo(torch.float8_e4m3fn) |
| fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32) |
| |
| # 2. Tiling and Scale Calculation |
| if ndim == 2: |
| s0, s1 = scale_shape |
| t0, t1 = tile_shape |
| if scale_major_mode == "K": |
| # Tile x and find the max absolute value in each tile |
| x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) |
| abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) |
| x_scale = abs_max / fp8_amax |
| x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) |
| |
| # Broadcast scales back to the original tensor shape |
| scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1) |
| else: |
| # Handle column-major tiling |
| x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) |
| abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) |
| x_scale = abs_max / fp8_amax |
| x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) |
| |
| # Permute scale axes before repeating to match layout |
| scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0") |
| scales_repeated = repeat(scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1) |
| |
| elif ndim == 3: |
| s0, s1, s2 = scale_shape |
| t0, t1, t2 = tile_shape |
| if scale_major_mode == "K": |
| # Tile x and find the max absolute value in each tile |
| x_tiled = rearrange( |
| x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 |
| ) |
| abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4) |
| x_scale = abs_max / fp8_amax |
| x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) |
| |
| # Broadcast scales back to the original tensor shape |
| scales_repeated = repeat( |
| x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2 |
| ) |
| else: |
| # Handle layout where the last two axes are swapped |
| x_tiled = rearrange( |
| x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 |
| ) |
| abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4) |
| x_scale = abs_max / fp8_amax |
| x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) |
| # Permute scale axes before repeating to match layout |
| scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1") |
| scales_repeated = repeat( |
| scales_permuted, |
| "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)", |
| t0=t0, |
| t1=t1, |
| t2=t2, |
| ) |
| # 3. Final Quantization |
| # Divide the original tensor by the broadcasted scales |
| x_fp32 = x / (scales_repeated + 1e-8) |
| |
| # Convert the result to the target FP8 format |
| x_fp8 = x_fp32.to(torch.float8_e4m3fn) |
| |
| return x_fp8, x_scale |
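
# Illustrative usage of quantize_fp8 (hypothetical shapes; "K" scale-major mode with
# 1x128 tiles, matching the row-wise-A / block-wise-B configuration tested below):
#   x = torch.randn(256, 512, device="cuda", dtype=torch.float32)
#   x_fp8, x_scale = quantize_fp8(
#       x, scale_shape=(256, 4), tile_shape=(1, 128), scale_major_mode="K"
#   )
#   # x_fp8: (256, 512) float8_e4m3fn; x_scale: (256, 4) float32 power-of-two scales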
| |
| |
def dequantize_fp8(x, x_scale, scale_major_mode):
    """
    Dequantizes an FP8 tensor back to float32 using its per-tile scales.

    Args:
        x (torch.Tensor): The 2D or 3D FP8 input tensor.
        x_scale (torch.Tensor): The float32 per-tile scales produced by
            quantize_fp8.
        scale_major_mode (str): The tiling order, "K" for row-major-like,
            or any other value for column-major-like.

    Returns:
        torch.Tensor: The dequantized float32 tensor.
    """
    from einops import rearrange  # pylint: disable=import-outside-toplevel

| # 1. Assertions and Initial Setup |
| ndim = x.ndim |
| assert ndim == len(x_scale.shape) |
| |
| # 2. Tiling and Scale Calculation |
| if ndim == 2: |
| if scale_major_mode == "K": |
| s0, s1 = x_scale.shape |
| else: |
| s1, s0 = x_scale.shape |
| x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) |
| if scale_major_mode == "K": |
| x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1") |
| else: |
| x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1") |
| out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)") |
| elif ndim == 3: |
| if scale_major_mode == "K": |
| s0, s1, s2 = x_scale.shape |
| else: |
| s0, s2, s1 = x_scale.shape |
| x = rearrange( |
| x.to(torch.float32), |
| "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2", |
| s0=s0, |
| s1=s1, |
| s2=s2, |
| ) |
| if scale_major_mode == "K": |
| x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1") |
| else: |
| x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") |
| out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)") |
| |
| return out |
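
# Illustrative round trip (continuing the hypothetical example above):
#   x_approx = dequantize_fp8(x_fp8, x_scale, scale_major_mode="K")
#   # x_approx: (256, 512) float32, elementwise close to the original x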
| |
| |
| ########################################### |
########### Reference generation ##########
| ########################################### |
def compute_reference_grouped_gemm(
    a_fp32: torch.Tensor,  # (total_m, k)
    b_fp32: torch.Tensor,  # (batch_size, n, k)
    m_indptr: torch.Tensor,  # (batch_size + 1,)
    dtype_out: str,
):
    """Compute the reference result of shape (total_m, n) with PyTorch on the original FP32 tensors."""
| |
| total_m, k = a_fp32.shape |
| batch_size, n, k2 = b_fp32.shape |
| assert k == k2 |
| |
| # Perform grouped GEMM computation directly on original FP32 data |
| results = [] |
| |
| for i in range(batch_size): |
| start_m = m_indptr[i].item() |
| end_m = m_indptr[i + 1].item() |
| |
| # Extract group's portion of A |
| a_group = a_fp32[start_m:end_m, :] # [m_sizes[i], k] |
| b_group = b_fp32[i] |
| |
        # Multiply by this group's B matrix (stored as (n, k), hence the transpose)
        result_group = torch.mm(a_group, b_group.T)  # [m_sizes[i], n]
| results.append(result_group) |
| |
| result_fp32 = torch.cat(results, dim=0) |
| |
| # Convert to output dtype |
| if dtype_out == "bfloat16": |
| result = result_fp32.to(torch.bfloat16) |
| elif dtype_out == "float16": |
| result = result_fp32.to(torch.float16) |
| else: |
| result = result_fp32 |
| |
| return result |
| |
| |
| ########################################### |
| ########### Test data generation ########## |
| ########################################### |
| def generate_test_data( |
| m_sizes: list, |
| batch_size: int, |
| n: int, |
| k: int, |
| dtype_a: str, |
| dtype_b: str, |
| dtype_out: str, |
| scale_granularity_m: int, |
| scale_granularity_n: int, |
| scale_granularity_k: int, |
| scale_major_mode: str, |
| device: tvm.runtime.Device, |
| ): |
| """Generate test data for grouped GEMM operations""" |
| assert batch_size == len( |
| m_sizes |
| ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})" |
| |
| # print(f"Device object: {device}") |
| torch_device = torch.device(f"cuda:{device.index}") |
| |
| cum_m = [0] + list(np.cumsum(m_sizes)) |
| total_m = cum_m[-1] |
| |
    # Generate A and B as random FP32 data first, then quantize to FP8 below
| assert dtype_a == "float8_e4m3fn" |
| a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32) |
| |
| assert dtype_b == "float8_e4m3fn" |
| b_fp32 = torch.randn(batch_size, n, k, device=torch_device, dtype=torch.float32) / math.sqrt(k) |
| |
| if scale_major_mode == "K": # K mode: |
| scale_a_shape = (total_m // scale_granularity_m, k // scale_granularity_k) |
| scale_b_shape = (batch_size, n // scale_granularity_n, k // scale_granularity_k) |
| |
    else:  # MN mode: M/N tiles are the innermost scale dimension
| scale_a_shape = (k // scale_granularity_k, total_m // scale_granularity_m) |
| scale_b_shape = (batch_size, k // scale_granularity_k, n // scale_granularity_n) |
| |
| tile_a_shape = (scale_granularity_m, scale_granularity_k) |
| tile_b_shape = (1, scale_granularity_n, scale_granularity_k) |
| |
| # quantize A, B |
| a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, scale_major_mode) |
| b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode) |
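
    # torch FP8 tensors cannot be exported to numpy directly; the conversion below
    # reinterprets the raw bytes through a uint8 view and re-views them as
    # float8_e4m3fn on the numpy side (assumes an fp8-capable numpy dtype, e.g. from
    # ml_dtypes, is available).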
| |
| if dtype_a == "float8_e4m3fn": |
| a_tvm = tvm.runtime.tensor( |
| a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device |
| ) |
| else: |
| a_tvm = tvm.runtime.from_dlpack(a_quantized) |
| |
| if dtype_b == "float8_e4m3fn": |
| b_tvm = tvm.runtime.tensor( |
| b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device |
| ) |
| else: |
| b_tvm = tvm.runtime.from_dlpack(b_quantized) |
| |
| scale_a_tvm = tvm.runtime.from_dlpack(scale_a) |
| scale_b_tvm = tvm.runtime.from_dlpack(scale_b) |
| |
| # Create m_indptr for grouped operation |
| m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32) |
| m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device) |
| |
| return { |
| "a": a_tvm, |
| "b": b_tvm, |
| "torch_a": a_fp32, |
| "torch_b": b_fp32, |
| "scale_a": scale_a_tvm, |
| "scale_b": scale_b_tvm, |
| "m_indptr": m_indptr_tvm, |
| "m_sizes": m_sizes, |
| "n": n, |
| "k": k, |
| "total_m": total_m, |
| "torch_scale_a": scale_a, |
| "torch_scale_b": scale_b, |
| "torch_m_indptr": m_indptr, |
| } |
| |
| |
| ########################################### |
| ############### Test driver ############### |
| ########################################### |
| @pytest.mark.skipif(not has_flashinfer(), reason="FlashInfer not available") |
| @pytest.mark.skipif(not has_cutlass(), reason="CUTLASS SM90+ not available") |
| @pytest.mark.parametrize( |
| "dtype_a,dtype_b,dtype_out", |
| [ |
| ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"), |
| ("float8_e4m3fn", "float8_e4m3fn", "float16"), |
| ], |
| ) |
| @pytest.mark.parametrize( |
| "scale_granularity_m,scale_granularity_n,scale_granularity_k", |
| [ |
| (1, 128, 128), # Row-wise A, block-wise B |
| ], |
| ) |
| @pytest.mark.parametrize("scale_major_mode", ["K", "MN"]) |
| @pytest.mark.parametrize("mma_sm", [1, 2]) |
| @pytest.mark.parametrize( |
| "test_case", |
| [ |
| {"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k": 1024}, |
| {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512}, |
| {"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768}, |
| {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768}, |
| ], |
| ) |
| def test_grouped_gemm_correctness( |
| dtype_a, |
| dtype_b, |
| dtype_out, |
| scale_granularity_m, |
| scale_granularity_n, |
| scale_granularity_k, |
| scale_major_mode, |
| mma_sm, |
| test_case, |
| ): |
| """Test correctness of GroupedGemm operations""" |
| device = tvm.cuda(0) |
| target = tvm.target.Target.from_device(device) |
| |
| # Generate the module |
| mod = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(target=target)[0] |
| |
    # Look up the grouped GEMM function in the module
    grouped_gemm_fn = mod["group_gemm_fp8_nt_groupwise"]
| |
| # Generate test data |
| test_data = generate_test_data( |
| batch_size=test_case["batch_size"], |
| m_sizes=test_case["m_sizes"], |
| n=test_case["n"], |
| k=test_case["k"], |
| dtype_a=dtype_a, |
| dtype_b=dtype_b, |
| dtype_out=dtype_out, |
| scale_granularity_m=scale_granularity_m, |
| scale_granularity_n=scale_granularity_n, |
| scale_granularity_k=scale_granularity_k, |
| scale_major_mode=scale_major_mode, |
| device=device, |
| ) |
| |
| # Prepare output buffer |
| output_shape = (test_data["total_m"], test_data["n"]) |
| if dtype_out == "bfloat16": |
| output = tvm.runtime.empty(output_shape, dtype="bfloat16", device=device) |
| elif dtype_out == "float16": |
| output = tvm.runtime.empty(output_shape, dtype="float16", device=device) |
| else: |
| output = tvm.runtime.empty(output_shape, dtype="float32", device=device) |
| |
| # Create workspace buffers (required by the interface) |
| int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device) |
| float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device) |
| |
| grouped_gemm_fn( |
| int_workspace, # int_workspace_buffer |
| float_workspace, # float_workspace_buffer |
| test_data["a"], # A |
| test_data["b"], # B |
| test_data["scale_a"], # SFA |
| test_data["scale_b"], # SFB |
| output, # D |
| test_data["m_indptr"], # m_indptr |
| test_data["n"], # n (scalar) |
| test_data["k"], # k (scalar) |
| scale_granularity_m, |
| scale_granularity_n, |
| scale_granularity_k, |
| scale_major_mode, |
| mma_sm, |
| ) |
| |
| # Compute reference result |
| reference = compute_reference_grouped_gemm( |
| test_data["torch_a"], |
| test_data["torch_b"], |
| test_data["torch_m_indptr"], |
| dtype_out, |
| ) |
| |
| # Convert TVM output to PyTorch for comparison |
| output_torch = torch.as_tensor(output, device=test_data["torch_a"].device) |
| |
| # Check shapes match |
| assert ( |
| output_torch.shape == reference.shape |
| ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}" |
| |
    # Compare against the FP32 reference using the relative-difference metric
    diff = calc_diff(output_torch.cpu().double().numpy(), reference.cpu().double().numpy())
    assert diff < 1e-3, f"diff too large: {diff}"
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |