# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
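"""Tests for CUDA codegen and execution of FP8 (e4m3/e5m2) and vectorized float16 kernels."""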
import sys
import pytest
import tvm
from tvm.script import tir as T
import numpy as np
import tvm.testing
from typing import List, Tuple
from tvm import DataType, DataTypeCode, IRModule
from tvm import dlight as dl
from tvm import relax, te, tir, topi
from tvm.relax.frontend import nn
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.topi.utils import get_const_tuple
from tvm.script import ir as I, relax as R, tir as T
try:
import ml_dtypes
except ImportError:
ml_dtypes = None
@tvm.testing.requires_cuda_compute_version(9)
def test_e4m3_conversions():
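    """Add two e4m3_float8 buffers by promoting to float16, then check that the
    generated CUDA uses __nv_fp8_e4m3 and that the result matches numpy."""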
dtype = "e4m3_float8"
@T.prim_func
def add(
A: T.Buffer((64,), dtype),
B: T.Buffer((64,), dtype),
C: T.Buffer((64,), dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i in range(64):
with T.block("C"):
v_i = T.axis.spatial(64, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = T.Cast(dtype, T.Cast("float16", A[v_i]) + T.Cast("float16", B[v_i]))
sch = tvm.tir.Schedule(add)
block = sch.get_block("C")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
target = "cuda"
fadd = tvm.build(sch.mod, target=target)
cuda_src = fadd.imported_modules[0].get_source()
assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA"
dev = tvm.device(target, 0)
numpytype = "float8_e4m3fn"
a = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(numpytype), dev)
b = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(numpytype), dev)
c = tvm.nd.array(np.zeros(64, dtype=numpytype), dev)
fadd(a, b, c)
tvm.testing.assert_allclose(
c.numpy().astype("float16"), (a.numpy() + b.numpy()).astype("float16")
)
@tvm.testing.requires_cuda_compute_version(9)
def test_e4m3_packing():
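    """Round-trip an e4m3_float8x4 buffer through a uint32 reinterpret and back,
    checking that the packing is lossless."""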
length = 64
vector_length = 4
native_dtype, packed_dtype = ("e4m3_float8x4", "uint32")
@T.prim_func
def add(
A: T.Buffer((length,), native_dtype),
R: T.Buffer((length,), packed_dtype),
B: T.Buffer((length,), native_dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i in range(length):
with T.block("R"):
v_i = T.axis.spatial(length, i)
T.reads(A[v_i])
T.writes(R[v_i])
R[v_i] = T.reinterpret(packed_dtype, A[v_i])
for i in range(length):
with T.block("B"):
v_i = T.axis.spatial(length, i)
T.reads(R[v_i])
T.writes(B[v_i])
B[v_i] = T.reinterpret(native_dtype, R[v_i])
sch = tvm.tir.Schedule(add)
block = sch.get_block("R")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
block = sch.get_block("B")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
target = "cuda"
f = tvm.build(sch.mod, target=target)
dev = tvm.device(target, 0)
numpytype = "float8_e4m3fn"
np_shape = (length, vector_length)
a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
a = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev)
r = tvm.nd.empty(shape=(length,), dtype=packed_dtype, device=dev)
b = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev)
a.copyfrom(a_np)
f(a, r, b)
tvm.testing.assert_allclose(a.numpy().astype("float16"), b.numpy().astype("float16"))
native_dtype, promoted_dtype = tvm.testing.parameters(
("e4m3_float8", "float32"),
("e4m3_float8", "float16"),
("e4m3_float8x2", "float32x2"),
("e4m3_float8x2", "float16x2"),
("e4m3_float8x4", "float32x4"),
# Supported via half4 vector type extension in codegen
("e4m3_float8x4", "float16x4"),
)
@tvm.testing.requires_cuda_compute_version(9)
def test_e4m3_vector_conversions(native_dtype, promoted_dtype):
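    """Vectorized add of e4m3 buffers with promotion to the parameterized
    float16/float32 scalar and vector dtypes."""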
vector_length = 64
@T.prim_func
def add(
A: T.Buffer((vector_length,), native_dtype),
B: T.Buffer((vector_length,), native_dtype),
C: T.Buffer((vector_length,), native_dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i in range(vector_length):
with T.block("C"):
v_i = T.axis.spatial(vector_length, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = T.Cast(
native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i])
)
sch = tvm.tir.Schedule(add)
block = sch.get_block("C")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
target = "cuda"
fadd = tvm.build(sch.mod, target=target)
cuda_src = fadd.imported_modules[0].get_source()
dev = tvm.device(target, 0)
numpytype = "float8_e4m3fn"
if "x" in native_dtype:
lanes = int(native_dtype.split("x")[-1])
else:
lanes = 1
if "x" in promoted_dtype:
promoted_base_dtype = promoted_dtype.split("x")[0]
else:
promoted_base_dtype = promoted_dtype
np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,)
a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
a.copyfrom(a_np)
b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
b.copyfrom(b_np)
c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev)
fadd(a, b, c)
tvm.testing.assert_allclose(
c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype)
)
bcast_length = tvm.testing.parameter(2, 4, 6, 8)
@tvm.testing.requires_cuda_compute_version(8)
def test_half_broadcast(bcast_length):
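    """Broadcast a scalar float16 value into a vector of bcast_length elements on the GPU."""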
dtype = "float16"
@T.prim_func
    def vector_broadcast(a: T.Buffer((), dtype), vec: T.Buffer((bcast_length,), dtype)):
for t in range(1):
with T.block("broadcast"):
vec[0:bcast_length] = T.broadcast(a[()], bcast_length)
sch = tvm.tir.Schedule(vector_broadcast)
block = sch.get_block("broadcast")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 1])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
target = "cuda"
func = tvm.build(sch.mod, target=target)
dev = tvm.device(target, 0)
a_np = np.random.uniform(low=0, high=4, size=()).astype(dtype)
a = tvm.nd.array(a_np, device=dev)
b = tvm.nd.empty((bcast_length,), dtype=dtype, device=dev)
func(a, b)
b_np = np.full((bcast_length,), a_np)
tvm.testing.assert_allclose(b.numpy(), b_np)
vector_length = tvm.testing.parameter(2, 4)
@tvm.testing.requires_cuda_compute_version(8)
def test_half_misaligned_vector_load(vector_length):
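    """Gather float16 vectors through a negative-stride (reversed) ramp index,
    exercising misaligned vector loads."""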
dtype = "float16"
vec_dtype = dtype + "x" + str(vector_length)
length = 256
@T.prim_func
def vector_load(
        A: T.Buffer((length,), dtype), B: T.Buffer((length // vector_length,), vec_dtype)
):
for b in T.thread_binding(1, thread="blockIdx.x"):
for i in T.thread_binding(length // vector_length, thread="threadIdx.x"):
vec_index = T.ramp((i + 1) * vector_length - 1, -1, vector_length)
B[i] = A[vec_index]
target = "cuda"
f = tvm.build(vector_load, target=target)
dev = tvm.device(target, 0)
a_np = np.random.uniform(low=0, high=1, size=(length,)).astype(dtype)
a = tvm.nd.array(a_np, device=dev)
b = tvm.nd.empty((length // vector_length,), dtype=vec_dtype, device=dev)
f(a, b)
b_np = np.empty((length // vector_length, vector_length), dtype=dtype)
for i in range(length // vector_length):
start_index = (i + 1) * vector_length - 1
b_np[i, :] = a_np[start_index - vector_length + 1 : start_index + 1][::-1]
tvm.testing.assert_allclose(b.numpy(), b_np)
@tvm.testing.requires_cuda_compute_version(8)
def test_half4_vector_add():
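    """Element-wise add of float16x4 buffers, exercising the half4 vector type in CUDA codegen."""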
dtype = "float16"
length = 64
vector_length = 4
vec_dtype = dtype + "x" + str(vector_length)
@T.prim_func
def add(
A: T.Buffer((length,), vec_dtype),
B: T.Buffer((length,), vec_dtype),
C: T.Buffer((length,), vec_dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i in range(length):
with T.block("C"):
v_i = T.axis.spatial(length, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] + B[v_i]
sch = tvm.tir.Schedule(add)
block = sch.get_block("C")
b = sch.get_loops(block)
bx, tx = sch.split(b[0], factors=[None, 32])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
target = "cuda"
fadd = tvm.build(sch.mod, target=target)
dev = tvm.device(target, 0)
a_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype)
a = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
a.copyfrom(a_np)
b_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype)
b = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
b.copyfrom(b_np)
c = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
fadd(a, b, c)
c_expected = a_np + b_np
tvm.testing.assert_allclose(c.numpy(), c_expected, atol=1e-5, rtol=1e-5)
class BaseFP8E4M3QuantScaleOnly:
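    """Builders for group-wise FP8 (e4m3) quantization with a per-group scale and the
    matching dequantization, shared by the scale-only quantization tests below."""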
@classmethod
def create_quantize_func(
cls,
weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_elem_per_storage,
max_int_value,
axis,
output_transpose,
) -> IRModule:
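        """Construct a Relax IRModule whose "main" quantizes a weight tensor to packed
        FP8 and returns the packed weight together with the per-group scale."""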
if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
quantize_func = cls.quantize_fp8x4_e4m3
else:
            raise NotImplementedError(f"Unsupported quantize dtype: {quantize_dtype}")
bb = relax.BlockBuilder() # pylint: disable=invalid-name
weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, model_dtype))
compute_scale, compute_quantize, compute_transpose = quantize_func(
weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_elem_per_storage,
max_int_value,
axis,
output_transpose,
)
with bb.function(name="main", params=[weight_var]):
with bb.dataflow():
lv_scale = bb.emit_te(compute_scale, weight_var)
lv_quantized_weight = compute_quantize(bb, (weight_var, lv_scale))
if compute_transpose:
lv_output = bb.emit_te(compute_transpose, lv_quantized_weight, lv_scale)
lv_quantized_weight = lv_output[0]
lv_scale = lv_output[1]
tuple_output = bb.emit((lv_quantized_weight, lv_scale))
gv = bb.emit_output(tuple_output)
bb.emit_func_output(gv)
return bb.finalize()
@classmethod
def create_dequantize_func(
cls,
packed_weight_shape,
scale_shape,
dequantized_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_elem_per_storage,
axis,
) -> IRModule:
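        """Construct a Relax IRModule whose "main" dequantizes a packed FP8 weight
        back to the model dtype using the per-group scale."""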
if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
dequantize_func = cls.dequantize_fp8x4_e4m3
else:
            raise NotImplementedError(f"Unsupported quantize dtype: {quantize_dtype}")
bb = relax.BlockBuilder() # pylint: disable=invalid-name
packed_weight_var = relax.Var(
"weight", relax.TensorStructInfo(packed_weight_shape, storage_dtype)
)
scale_var = relax.Var("scale", relax.TensorStructInfo(scale_shape, model_dtype))
compute_dequantize = dequantize_func(
packed_weight_shape,
scale_shape,
dequantized_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_elem_per_storage,
axis,
)
with bb.function(name="main", params=[packed_weight_var, scale_var]):
with bb.dataflow():
lv = compute_dequantize(bb, (packed_weight_var, scale_var))
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.finalize()
@classmethod
def quantize_fp8x4_e4m3( # pylint: disable=too-many-locals
cls,
weight_shape: List[tir.PrimExpr],
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_elem_per_storage,
max_int_value,
axis: int = -1,
output_transpose: bool = False,
) -> Tuple[te.Tensor, te.Tensor]:
"""Group quantization for weight tensor, defined in tensor expression."""
max_int = tir.const(max_int_value, model_dtype)
shape = weight_shape # pylint: disable=invalid-name
axis = axis if axis >= 0 else len(shape) + axis
k = shape[axis]
quantize_dtype = DataType(quantize_dtype)
# compute scale per group
r = te.reduce_axis((0, group_size), name="r") # pylint: disable=invalid-name
num_group = tir.ceildiv(k, group_size)
# (4096, 4096) -> quantize axis = 0, group size = 32 -> (128, 4096)
# for channel quant group_size = 4096 -> (1, 4096)
scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :])
def compute_scale(weight: te.Tensor):
min_scaling_factor = tir.const(1.0 / (max_int_value * 512.0), model_dtype)
max_abs = te.compute(
shape=scale_shape,
fcompute=lambda *idx: te.max(
tir.if_then_else(
idx[axis] * group_size + r < k,
te.abs(weight(*idx[:axis], idx[axis] * group_size + r, *idx[axis + 1 :])),
te.min_value(model_dtype),
),
axis=r,
),
name="max_abs_value",
)
scale = te.compute(
scale_shape,
lambda *idx: te.max(
max_abs(*idx).astype(model_dtype) / max_int, min_scaling_factor
),
name="scale",
)
return scale
def compute_quantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
# compute scaled weight
packed_shape = (weight_shape[0], weight_shape[1] // num_elem_per_storage)
quant = cls.quant_and_pack_fp8x4_e4m3_sm90(
weight_shape,
packed_shape,
scale_shape,
group_size,
axis,
model_dtype,
storage_dtype,
quantize_dtype,
)
# quant.show()
global_var = bb.add_func(quant, "quantized_weight")
lv_quantized_weight = bb.emit(
relax.call_tir(
global_var, args, relax.TensorStructInfo(packed_shape, storage_dtype)
)
)
return lv_quantized_weight
compute_transpose = None
if output_transpose:
def compute_transpose(quantized_weight: te.Tensor, scale: te.Tensor):
if len(quantized_weight.shape) != 2 or len(scale.shape) != 2:
raise ValueError(
"Does not support transpose output quantized weight with ndim != 2"
)
quantized_weight = topi.transpose(quantized_weight)
scale = topi.transpose(scale)
return quantized_weight, scale
return compute_scale, compute_quantize_weight, compute_transpose
@classmethod
def dequantize_fp8x4_e4m3( # pylint: disable=too-many-locals
cls,
packed_weight_shape: List[tir.PrimExpr],
scale_shape,
dequant_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_elem_per_storage,
axis: int = -1,
) -> Tuple[te.Tensor, te.Tensor]:
"""Group quantization for weight tensor, defined in tensor expression."""
        axis = axis if axis >= 0 else len(dequant_shape) + axis
def compute_dequantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
dequant = cls.dequant_fp8x4_e4m3_sm90(
packed_weight_shape,
scale_shape,
dequant_shape,
group_size,
axis,
model_dtype,
storage_dtype,
quantize_dtype,
)
global_var = bb.add_func(dequant, "dequantize_weight")
lv_dequantized_weight = bb.emit(
relax.call_tir(global_var, args, relax.TensorStructInfo(dequant_shape, model_dtype))
)
return lv_dequantized_weight
return compute_dequantize_weight
@classmethod
def quant_and_pack_fp8x4_e4m3_sm90(
cls,
weight_shape,
packed_shape,
scale_shape,
group_size,
axis,
model_dtype,
storage_dtype,
quantized_dtype,
):
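        """Build a TIR function that divides groups of 4 model-dtype values by their scale,
        casts them to e4m3x4, and reinterprets each vector as one uint32 storage word."""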
vector_length = 4
vec_quantized_dtype = f"{quantized_dtype}x{vector_length}"
vec_model_dtype = f"{model_dtype}x{vector_length}"
num_elem_per_storage = vector_length
# TODO(csullivan) assert on storage dtype / quantize type bytes == vector length
assert (
group_size % vector_length == 0
), f"Number of elements in a group must be divisible by fp8 vector length {vector_length}"
@T.prim_func(private=True)
def quant_pack(
A: T.Buffer(weight_shape, model_dtype),
scale: T.Buffer(scale_shape, model_dtype),
compute: T.Buffer(
packed_shape,
storage_dtype,
),
):
# with T.block("root"):
# test = T.alloc_buffer(1, dtype=vec_model_dtype, scope="local")
for i0, i1 in T.grid(
T.int64(weight_shape[0]), T.int64(weight_shape[1] // vector_length)
):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(
                        A[v_i0, v_i1 * vector_length : v_i1 * vector_length + vector_length],
                        scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
                    )
                    T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.reinterpret(
storage_dtype,
T.Cast(
vec_quantized_dtype,
A[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)]
/ scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
),
)
return quant_pack
@classmethod
def dequant_fp8x4_e4m3_sm90(
cls,
packed_weight_shape,
scale_shape,
out_shape,
group_size,
axis,
model_dtype,
storage_dtype,
quantized_dtype,
):
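        """Build a TIR function that reinterprets each uint32 storage word as e4m3x4,
        casts it back to the model dtype, and multiplies by the per-group scale."""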
vector_length = 4
vec_quantized_dtype = f"{quantized_dtype}x{vector_length}"
vec_model_dtype = f"{model_dtype}x{vector_length}"
num_elem_per_storage = vector_length
@T.prim_func
def dequant(
packed_weight: T.Buffer(packed_weight_shape, storage_dtype),
scale: T.Buffer(scale_shape, model_dtype),
dequantize: T.Buffer(out_shape, model_dtype),
):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1 in T.grid(T.int64(packed_weight_shape[0]), T.int64(packed_weight_shape[1])):
with T.block("dequantize"):
v_i0 = T.axis.spatial(T.int64(packed_weight_shape[0]), i0)
v_i1 = T.axis.spatial(T.int64(packed_weight_shape[1]), i1)
T.reads(
packed_weight[v_i0, v_i1],
scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
)
dequantize[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)] = T.Cast(
vec_model_dtype,
T.reinterpret(vec_quantized_dtype, packed_weight[v_i0, v_i1]),
) * T.Broadcast(
scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
vector_length,
)
return dequant
@classmethod
def compile_quant_and_dequant_by_scale(
cls,
weight_shape,
scales_shape,
quant_weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_el_per_storage,
max_int_value,
axis,
target_str,
dev,
):
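        """Compile the quantize and dequantize modules with the default dlight GPU
        schedules and return the two VM "main" entry points."""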
quant_mod = cls.create_quantize_func(
weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_el_per_storage,
max_int_value,
axis,
output_transpose=False,
)
# quant_mod.show()
target = tvm.target.Target(target_str)
with target:
quant_mod = dl.ApplyDefaultSchedule(
dl.gpu.Reduction(),
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
)(quant_mod)
ex_1 = relax.build(quant_mod, target=target)
vm_1 = relax.VirtualMachine(ex_1, dev)
dequant_mod = cls.create_dequantize_func(
quant_weight_shape,
scales_shape,
weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_el_per_storage,
axis,
)
# dequant_mod.show()
with target:
dequant_mod = dl.ApplyDefaultSchedule(
dl.gpu.Reduction(),
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
)(dequant_mod)
dequant_mod.show()
ex_2 = relax.build(dequant_mod, target=target)
vm_2 = relax.VirtualMachine(ex_2, dev)
def print_cuda(target, mod, name=None):
if name:
mod = mod[name]
f = tvm.build(mod, target=target)
cuda_src = f.imported_modules[0].get_source()
print(cuda_src)
print_cuda(target, dequant_mod, name="dequant")
return vm_1["main"], vm_2["main"]
class TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly):
# weight_shape = tvm.testing.parameter((32000, 4096), (4096, 14336))
weight_shape = tvm.testing.parameter((128, 256), (128, 64))
@tvm.testing.fixture
def group_size(self):
return 64
@tvm.testing.fixture
def axis(self):
return 1
@tvm.testing.fixture
def model_dtype(self):
return "float16"
@tvm.testing.fixture
def storage_dtype(self):
return "uint32"
@tvm.testing.fixture
def quantize_dtype(self):
return "e4m3_float8"
@tvm.testing.fixture
def num_el_per_storage(self):
return 4
@tvm.testing.fixture
def max_int_value(self):
return 448
@tvm.testing.fixture
def target_str(self):
return "cuda"
@tvm.testing.fixture
def scale_shape(self, weight_shape, group_size, axis):
return [
(d + group_size - 1) // group_size if axis == i else d
for i, d in enumerate(weight_shape)
]
@tvm.testing.fixture
def quant_weight_shape(self, weight_shape, num_el_per_storage, axis):
return [
(d + num_el_per_storage - 1) // num_el_per_storage if axis == i else d
for i, d in enumerate(weight_shape)
]
@tvm.testing.fixture
def compiled_functions(
self,
weight_shape,
scale_shape,
quant_weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_el_per_storage,
max_int_value,
axis,
target_str,
):
dev = tvm.device(target_str, 0)
return self.compile_quant_and_dequant_by_scale(
weight_shape,
scale_shape,
quant_weight_shape,
model_dtype,
quantize_dtype,
storage_dtype,
group_size,
num_el_per_storage,
max_int_value,
axis,
target_str,
dev,
)
@tvm.testing.requires_cuda_compute_version(9)
def test_main(self, weight_shape, model_dtype, target_str, compiled_functions):
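        """Quantize a random float16 weight to packed e4m3 and dequantize it back,
        checking the round trip within a loose tolerance."""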
quant, dequant = compiled_functions
dev = tvm.device(target_str, 0)
weight_np = np.random.uniform(-100, 100, weight_shape).astype(model_dtype)
weight = tvm.nd.array(weight_np, device=dev)
quant_weight, scales = quant(weight)
quant_weight_np, scales_np = quant_weight.numpy(), scales.numpy()
dequant_weight = dequant(quant_weight, scales)
dequant_weight_np = dequant_weight.numpy()
tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2)
@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"])
def test_const(dtype):
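    """Store an fp8 constant through a vectorized local buffer and check the module builds for CUDA."""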
@T.prim_func
def func(A: T.Buffer((4,), dtype)) -> None:
A_local = T.alloc_buffer((4,), dtype=dtype, scope="local")
for tx in T.thread_binding(0, 4, "threadIdx.x"):
for i in T.vectorized(4):
A_local[i] = T.float32(1.0).astype(dtype)
A[tx] = A_local[tx]
mod = tvm.IRModule({"main": func})
tvm.build(mod, target="cuda")
num_experts = 8
reduce_size = 1792
spatial_size = 4096
@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes to be installed")
def test_moe_gemv_shfl_down_illegal_instr():
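    """Regression test: run a single-batch MoE dequantize + GEMV module end to end;
    see the note above vm["main"] about the shfl_down illegal-instruction failure."""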
global num_experts
global reduce_size
global spatial_size
@I.ir_module
class SingleBatchMoE_float8_e4m3:
@T.prim_func(private=True)
def moe_dequantize_gemv(
x_handle: T.handle,
w: T.Buffer((num_experts, spatial_size, reduce_size), "e4m3_float8"),
scale: T.Buffer((1,), "float16"),
indptr: T.Buffer((1, 2), "int32"),
o: T.Buffer((2, spatial_size), "float16"),
):
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
num_seq = T.int64()
x = T.match_buffer(x_handle, (num_seq, reduce_size), "float16")
for expert_id in T.thread_binding(2, thread="blockIdx.y"):
with T.block("gemv_o"):
e = T.axis.spatial(2, expert_id)
T.reads(
w[indptr[0, e], 0:spatial_size, 0:reduce_size],
indptr[0, e],
scale[0],
x[e, 0:reduce_size],
)
T.writes(o[e, 0:spatial_size])
y = T.alloc_buffer((spatial_size, reduce_size), "float16")
for i1, i2 in T.grid(spatial_size, reduce_size):
with T.block("dequantize"):
i, j = T.axis.remap("SS", [i1, i2])
T.reads(w[indptr[0, e], i, j], indptr[0, e], scale[0])
T.writes(y[i, j])
y[i, j] = T.Cast("float16", w[indptr[0, e], i, j]) * scale[0]
for i1, i2 in T.grid(spatial_size, reduce_size):
with T.block("gemv"):
i, j = T.axis.remap("SR", [i1, i2])
T.reads(x[e, j], y[i, j])
T.writes(o[e, i])
with T.init():
o[e, i] = T.float16(0)
o[e, i] = o[e, i] + x[e, j] * y[i, j]
@R.function
def main(
x: R.Tensor(("num_seq", reduce_size), dtype="float16"),
indptr: R.Tensor((1, 2), dtype="int32"),
weight: R.Tensor((num_experts, spatial_size, reduce_size), dtype="e4m3_float8"),
scale: R.Tensor((1,), dtype="float32"),
) -> R.Tensor((2, spatial_size), dtype="float16"):
num_seq = T.int64()
R.func_attr({"num_input": 2})
cls = SingleBatchMoE_float8_e4m3
with R.dataflow():
astype: R.Tensor((1,), dtype="float16") = R.astype(scale, dtype="float16")
lv = R.call_tir(
cls.moe_dequantize_gemv,
(x, weight, astype, indptr),
out_sinfo=R.Tensor((2, spatial_size), dtype="float16"),
)
gv: R.Tensor((2, spatial_size), dtype="float16") = lv
R.output(gv)
return gv
def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
seq = tvm.transform.Sequential(
[
tvm.relax.transform.LegalizeOps(),
tvm.dlight.ApplyDefaultSchedule(
tvm.dlight.gpu.Matmul(),
tvm.dlight.gpu.GEMV(),
tvm.dlight.gpu.Reduction(),
tvm.dlight.gpu.GeneralReduction(),
tvm.dlight.gpu.Fallback(),
),
]
)
mod = seq(mod)
return mod
mod = SingleBatchMoE_float8_e4m3
target = tvm.target.Target("cuda")
    # Enter both the PassContext and the target scope ("and" here would only enter the target).
    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}), target:
mod = _pipeline(mod)
rt_mod = tvm.relax.build(mod, target=target)
dev = tvm.cuda(0)
x_data = np.zeros((1, reduce_size), dtype=np.float16)
x = tvm.nd.array(x_data, device=dev)
indptr_data = np.zeros((1, 2), dtype=np.int32)
indptr = tvm.nd.array(indptr_data, device=dev)
weight_data = np.zeros((num_experts, spatial_size, reduce_size), dtype="float8_e4m3fn")
weight = tvm.nd.array(weight_data, device=dev)
scale_data = np.zeros((1,), dtype=np.float32)
scale = tvm.nd.array(scale_data, device=dev)
vm = relax.VirtualMachine(rt_mod, dev)
# Ensure this runs without failure. Utilizing dlight thread extents TS, TR = 4, 64
# in GEMV scheduling will yield: CUDA: an illegal instruction was encountered.
vm["main"](x, indptr, weight, scale)
if __name__ == "__main__":
tvm.testing.main()