# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Tuple

import ml_dtypes
import numpy as np

import tvm
import tvm.testing
from tvm.contrib.pickle_memoize import memoize


def get_random_tensor(shape, dtype):
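    """Generate a random NumPy tensor of the given shape and dtype (integer or floating point)."""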
if dtype == "int8":
return np.random.randint(-128, 128, shape).astype(dtype)
elif dtype == "uint8":
return np.random.randint(0, 256, shape).astype(dtype)
    return np.random.uniform(-1, 1, shape).astype(dtype)


def verify_group_gemm(
func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype, use_scale, rtol, atol
):
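    """Run the packed CUTLASS grouped GEMM function `func_name` and compare against NumPy.

    The M rows of the activation are split evenly across `num_groups` groups; group i is
    multiplied by the i-th (N, K) weight and the outputs are concatenated along the row axis.
    The check is skipped if `func_name` was not compiled into the runtime.
    """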
group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
if group_gemm_func is None:
print(f"Skipped as {func_name} is not available")
return
@memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
def get_ref_data():
assert M % num_groups == 0
M_per_group = M // num_groups
a_np = get_random_tensor((M, K), x_dtype)
b_np = get_random_tensor((num_groups, N, K), weight_dtype)
indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
c_np = np.concatenate(
[a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)],
axis=0,
)
return a_np, b_np, indptr_np, c_np
def to_numpy_dtype(dtype):
mapping = {"float8_e5m2": ml_dtypes.float8_e5m2, "float8_e4m3fn": ml_dtypes.float8_e4m3fn}
return mapping.get(dtype, dtype)
a_np, b_np, indptr_np, c_np = get_ref_data()
dev = tvm.cuda(0)
a_nd = tvm.runtime.tensor(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
b_nd = tvm.runtime.tensor(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
c_nd = tvm.runtime.empty(c_np.shape, dtype=out_dtype, device=dev)
indptr_nd = tvm.runtime.tensor(indptr_np, device=dev)
workspace = tvm.runtime.empty((4096 * 1024,), dtype="uint8", device=dev)
if use_scale:
scale = tvm.runtime.tensor(np.array([1.0], dtype="float32"), device=dev)
group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
else:
group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd)
    tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=rtol, atol=atol)


@tvm.testing.requires_cutlass
@tvm.testing.requires_cuda_compute_version(9)
def test_group_gemm_sm90():
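    """Grouped GEMM on SM90 with float16 and float8 (e5m2, e4m3fn) inputs."""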
verify_group_gemm(
"cutlass.group_gemm",
8,
128,
128,
4,
"float16",
"float16",
"float16",
False,
rtol=1e-3,
atol=1e-3,
)
verify_group_gemm(
"cutlass.group_gemm_e5m2_e5m2_fp16",
8,
16,
16,
4,
"float8_e5m2",
"float8_e5m2",
"float16",
True,
rtol=1e-1,
atol=1,
)
verify_group_gemm(
"cutlass.group_gemm_e4m3_e4m3_fp16",
8,
16,
16,
4,
"float8_e4m3fn",
"float8_e4m3fn",
"float16",
True,
rtol=1e-1,
atol=1,
    )


@tvm.testing.requires_cutlass
@tvm.testing.requires_cuda_compute_version(10)
def test_group_gemm_sm100():
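    """Grouped GEMM on SM100 with bfloat16 inputs."""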
verify_group_gemm(
"cutlass.group_gemm",
8,
128,
128,
4,
"bfloat16",
"bfloat16",
"bfloat16",
False,
rtol=1e-2,
atol=1e-3,
    )


def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str):
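    """Quantize a random tensor to float8_e4m3fn with one scale per row per K-block.

    Returns the quantized tensor of shape `shape` and a float32 scale tensor of shape
    (*shape[:-1], ceil(shape[-1] / block_size[1])).
    """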
x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
x_scale_shape = (
*shape[:-1],
(shape[-1] + block_size[1] - 1) // block_size[1],
)
    # For each block of block_size[1] columns, compute the max abs value of `x_full_np`
    x_max_abs_np = np.zeros(x_scale_shape, dtype="float32")
    for i in range(x_scale_shape[-1]):
        x_max_abs_np[..., i] = np.max(
            np.abs(x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])]),
            axis=-1,
        )
    # The scale is `x_max_abs_np` divided by the max representable value of float8_e4m3fn
    fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max)
    x_scale_np = x_max_abs_np / fp8_max
    # `x_np` is `x_full_np` divided by `x_scale_np` (block by block),
    # clamped to (-fp8_max, fp8_max), and cast to float8_e4m3fn
x_np = np.zeros_like(x_full_np, dtype="float8_e4m3fn")
for i in range(x_scale_shape[-1]):
x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] = np.clip(
x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])]
/ x_scale_np[..., i : i + 1],
-fp8_max,
fp8_max,
)
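    # Replace the max-abs based scale with random per-block scales; the loop below recomputes
    # the dequantized `x_full_np` to match (only `x_np` and `x_scale_np` are returned).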
x_scale_np = np.random.rand(*x_scale_np.shape).astype("float32") / fp8_max
for i in range(x_scale_shape[-1]):
x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] = (
x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])].astype(
x_scale_np.dtype
)
* x_scale_np[..., i : i + 1]
)
    return x_np, x_scale_np


def blockwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str):
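    """Quantize a random tensor to float8_e4m3fn with one scale per (block_size[0], block_size[1]) block.

    Returns the quantized tensor of shape `shape` and a float32 scale tensor with one entry per
    block over the last two dimensions.
    """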
w_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype)
w_scale_shape = (
*shape[:-2],
(shape[-2] + block_size[0] - 1) // block_size[0],
(shape[-1] + block_size[1] - 1) // block_size[1],
)
# For each (block_size[0], block_size[1]) block, compute the max abs value of `w_full_np`
w_max_abs_np = np.zeros(w_scale_shape, dtype="float32")
for i in range(w_scale_shape[-2]):
for j in range(w_scale_shape[-1]):
block_shape = (
*shape[:-2],
min(block_size[0], shape[-2] - i * block_size[0]),
min(block_size[1], shape[-1] - j * block_size[1]),
)
w_max_abs_np[..., i, j] = np.max(
np.abs(
w_full_np[
...,
i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
]
).reshape(*shape[:-2], block_shape[-2] * block_shape[-1]),
axis=-1,
)
    # The scale is `w_max_abs_np` divided by the max representable value of float8_e4m3fn
    fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max)
    w_scale_np = w_max_abs_np / fp8_max
    # `w_np` is `w_full_np` divided by `w_scale_np` (block by block),
    # clamped to (-fp8_max, fp8_max), and cast to float8_e4m3fn
w_np = np.zeros_like(w_full_np, dtype="float8_e4m3fn")
if len(w_scale_shape) == 2:
for i in range(w_scale_shape[-2]):
for j in range(w_scale_shape[-1]):
w_np[
i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
] = np.clip(
w_full_np[
i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
]
/ w_scale_np[..., i, j],
-fp8_max,
fp8_max,
)
else:
for e in range(w_scale_shape[0]):
for i in range(w_scale_shape[-2]):
for j in range(w_scale_shape[-1]):
w_np[
e,
i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
] = np.clip(
w_full_np[
e,
i * block_size[0] : min((i + 1) * block_size[0], shape[-2]),
j * block_size[1] : min((j + 1) * block_size[1], shape[-1]),
]
/ w_scale_np[e, i, j],
-fp8_max,
fp8_max,
)
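    # Replace the max-abs based scale with random per-block scales for the returned pair.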
w_scale_np = np.random.rand(*w_scale_np.shape).astype("float32") / fp8_max
    return w_np, w_scale_np


def blockwise_matmul(
x_fp8_np: np.ndarray,
x_scale_np: np.ndarray,
w_np: np.ndarray,
w_scale_np: np.ndarray,
block_size: Tuple[int, int],
dtype: str,
):
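    """NumPy reference for groupwise-scaled FP8 GEMM: dequantize x and w per block, then o = x @ w.T.

    `x_scale_np` has one scale per row per K-block of `x`; `w_scale_np` has one scale per
    (block_size[0], block_size[1]) block of `w`.
    """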
o_np = np.zeros((x_fp8_np.shape[0], w_np.shape[0]), dtype=dtype)
for j in range(w_scale_np.shape[0]):
for k in range(w_scale_np.shape[1]):
o_np[:, j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[0])] += (
np.matmul(
x_fp8_np[
:, k * block_size[1] : min((k + 1) * block_size[1], x_fp8_np.shape[1])
].astype(dtype),
w_np[
j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[0]),
k * block_size[1] : min((k + 1) * block_size[1], w_np.shape[1]),
].T.astype(dtype),
)
* x_scale_np[:, k : k + 1]
* w_scale_np[j, k]
)
    return o_np


def blockwise_bmm(
x_fp8_np: np.ndarray,
x_scale_np: np.ndarray,
w_np: np.ndarray,
w_scale_np: np.ndarray,
block_size: Tuple[int, int],
dtype: str,
):
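    """Batched NumPy reference for groupwise-scaled FP8 GEMM with a leading batch dimension."""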
o_np = np.zeros((x_fp8_np.shape[0], x_fp8_np.shape[1], w_np.shape[1]), dtype=dtype)
for j in range(w_scale_np.shape[1]):
for k in range(w_scale_np.shape[2]):
o_np[..., j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[1])] += (
np.matmul(
x_fp8_np[
..., k * block_size[1] : min((k + 1) * block_size[1], x_fp8_np.shape[2])
].astype(dtype),
w_np[
...,
j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[1]),
k * block_size[1] : min((k + 1) * block_size[1], w_np.shape[2]),
]
.transpose(0, 2, 1)
.astype(dtype),
)
* x_scale_np[..., k : k + 1]
* w_scale_np[..., j : j + 1, k : k + 1]
)
    return o_np


@tvm.testing.requires_cutlass
@tvm.testing.requires_cuda_compute_version(9)
def test_fp8_e4m3_groupwise_scaled_gemm():
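    """Groupwise-scaled FP8 (e4m3fn) GEMM against the blockwise NumPy reference."""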
M = 16
N = 4608
K = 896
block_size = (128, 128)
    assert N % 128 == 0 and K % 128 == 0  # Only N and K that are multiples of 128 are supported
func_name = "cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn"
gemm_func = tvm.get_global_func(func_name, allow_missing=True)
if gemm_func is None:
print(f"Skipped as {func_name} is not available")
return
device = tvm.cuda(0)
dtype = "bfloat16"
x_np, x_scale_np = rowwise_quant_fp8_e4m3((M, K), block_size, dtype)
w_np, w_scale_np = blockwise_quant_fp8_e4m3((N, K), block_size, dtype)
o_np = blockwise_matmul(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype)
x_tvm = tvm.runtime.tensor(x_np, device=device)
x_scale_tvm = tvm.runtime.tensor(x_scale_np.T, device=device)
w_tvm = tvm.runtime.tensor(w_np, device=device)
w_scale_tvm = tvm.runtime.tensor(w_scale_np, device=device)
workspace = tvm.runtime.empty((4096 * 1024,), dtype="uint8", device=device)
o_tvm = tvm.runtime.empty((M, N), dtype=dtype, device=device)
gemm_func(
x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm
)
o_tvm = o_tvm.numpy()
    tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)


@tvm.testing.requires_cutlass
@tvm.testing.requires_cuda_compute_version(9)
def test_fp8_e4m3_groupwise_scaled_bmm():
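    """Groupwise-scaled FP8 (e4m3fn) batched GEMM against the blockwise NumPy reference."""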
B = 16
M = 40
N = 512
K = 128
block_size = (128, 128)
    assert N % 128 == 0 and K % 128 == 0  # Only N and K that are multiples of 128 are supported
func_name = "cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn"
gemm_func = tvm.get_global_func(func_name, allow_missing=True)
if gemm_func is None:
print(f"Skipped as {func_name} is not available")
return
device = tvm.cuda(0)
dtype = "bfloat16"
x_np, x_scale_np = rowwise_quant_fp8_e4m3((B, M, K), block_size, dtype)
w_np, w_scale_np = blockwise_quant_fp8_e4m3((B, N, K), block_size, dtype)
o_np = blockwise_bmm(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype)
x_tvm = tvm.runtime.tensor(x_np, device=device)
x_scale_tvm = tvm.runtime.tensor(x_scale_np.transpose(0, 2, 1), device=device)
w_tvm = tvm.runtime.tensor(w_np, device=device)
w_scale_tvm = tvm.runtime.tensor(w_scale_np, device=device)
workspace = tvm.runtime.empty((4096 * 1024,), dtype="uint8", device=device)
o_tvm = tvm.runtime.empty((B, M, N), dtype=dtype, device=device)
gemm_func(
x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm
)
o_tvm = o_tvm.numpy()
    tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5)


if __name__ == "__main__":
tvm.testing.main()