# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
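"""Tests for offloading Relax matmul workloads to the cuBLAS BYOC backend."""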
import numpy as np
import pytest
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relax
from tvm.relax.backend.contrib.cublas import partition_for_cublas
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
try:
import ml_dtypes
except ImportError:
ml_dtypes = None
@pytest.fixture(autouse=True)
def reset_seed():
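    """Reset the NumPy RNG before every test so random inputs are reproducible."""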
np.random.seed(0)
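# Every test in this file requires cuBLAS to be available.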
pytestmark = tvm.testing.requires_cublas.marks()
def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
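    """Build `mod` for `target`, run it on `inputs_np`, and return the output as NumPy.

    Relax op legalization and CUDA graph capture are toggled through PassContext options.
    """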
dev = tvm.device(target, 0)
with tvm.transform.PassContext(
config={
"relax.backend.use_cuda_graph": cuda_graph,
"relax.transform.apply_legalize_ops": legalize,
}
):
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, dev)
f = vm["main"]
inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
    # When CUDA graph is enabled, run the compiled function twice to make sure the
    # captured graph can be launched on the second run.
if cuda_graph:
f(*inputs)
return f(*inputs).numpy()
def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_constants=False):
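    """Partition matmul patterns to the cuBLAS BYOC backend, run its codegen, and execute on CUDA."""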
mod = partition_for_cublas(mod, bind_constants=bind_constants)
mod = relax.transform.RunCodegen()(mod)
    return build_and_run(mod, np_inputs, "cuda", cuda_graph=cuda_graph)
def _to_concrete_shape(symbolic_shape, var_table):
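    """Replace symbolic dims with random concrete extents, reusing the same extent for a repeated var."""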
result = []
for dim in symbolic_shape:
if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue
if dim not in var_table:
var_table[dim] = np.random.randint(10, 50)
result.append(var_table[dim])
return tuple(result)
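# Symbolic dimensions shared by the dynamic-shape test cases below.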
_vars = {
"a": tvm.tir.expr.Var("a", "int64"),
"b": tvm.tir.expr.Var("b", "int64"),
}
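# Maps an epilogue name to (with_bias, activation) used when building the test modules.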
_epilogue_table = {
"none": (False, None),
"bias": (True, None),
"relu": (True, R.nn.relu),
"gelu": (True, R.nn.gelu),
}
def get_relax_matmul_dequantize_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
transposed_y=False,
scale_const=1.0,
zero_point_const=0.0,
):
"""Create a matmul op followd by dequantize operations."""
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
x = R.arg("x", R.Tensor(x_shape, in_dtype))
y = R.arg("y", R.Tensor(y_shape, in_dtype))
with R.dataflow() as frame:
if transposed_y:
axes = list(range(len(y_shape) - 2)) + [-1, -2]
y = R.emit(R.permute_dims(y, axes=axes))
result = R.emit(R.matmul(x, y, out_dtype="float32"))
result = R.emit(
R.dequantize(
result,
scale=R.const(scale_const, "float16"),
zero_point=R.const(zero_point_const, "float16"),
axis=-1,
out_dtype=out_dtype,
)
)
R.output(result)
R.func_ret_value(frame.output_vars[0])
func = builder.get()
return tvm.IRModule({"main": func})
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, epilogue",
[
# Regular
((8, 8), (8, 8), False, "none"),
((_vars["a"], 6), (6, 16), False, "bias"),
# Transposed
((4, 16), (16, 128), True, "relu"),
((35, 8), (8, 8), True, "gelu"),
        # 3D x 3D
((6, 32, 8), (6, 8, 10), False, "bias"),
((6, 32, 8), (6, 8, 10), True, "none"),
((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
# ND x ND
((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
# ND x 2D
((5, 3, 32, 8), (8, 10), False, "none"),
],
)
@pytest.mark.parametrize(
"in_dtype, out_dtype",
[
("float16", "float16"),
("float32", "float32"),
],
)
def test_matmul_offload(
x_shape,
y_shape,
transpose_y,
epilogue,
in_dtype,
out_dtype,
):
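    """Offloaded fp16/fp32 matmul with an optional epilogue must match the legalized CPU reference."""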
with_bias, activation = _epilogue_table[epilogue]
var_table = {}
concrete_x_shape = _to_concrete_shape(x_shape, var_table)
concrete_y_shape = _to_concrete_shape(y_shape, var_table)
x = np.random.randn(*concrete_x_shape).astype(in_dtype)
y = np.random.randn(*concrete_y_shape).astype(in_dtype)
if transpose_y:
y = np.swapaxes(y, -2, -1)
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
if with_bias:
bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
args = (x, y, bias)
else:
bias = None
args = (x, y)
mod = get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
)
out = get_result_with_relax_cublas_offload(mod, args)
ref = build_and_run(mod, args, "llvm", legalize=True)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, epilogue",
[
# Regular
((8, 8), (8, 8), False, "none"),
((_vars["a"], 8), (8, 16), False, "none"),
# Transposed
((4, 16), (16, 128), True, "none"),
((35, 16), (16, 128), False, "none"),
        # 3D x 3D
((6, 32, 8), (6, 8, 12), False, "none"),
((6, 32, 8), (6, 8, 10), True, "none"),
((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "none"),
# ND x ND
((5, 3, 32, 8), (5, 3, 8, 12), False, "none"),
# ND x 2D
((5, 3, 32, 8), (8, 12), False, "none"),
],
)
def test_matmul_igemm_offload(
x_shape,
y_shape,
transpose_y,
epilogue,
):
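    """Offloaded int8 x int8 -> int32 matmul must match the legalized CPU reference."""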
in_dtype = "int8"
out_dtype = "int32"
with_bias, activation = _epilogue_table[epilogue]
var_table = {}
concrete_x_shape = _to_concrete_shape(x_shape, var_table)
concrete_y_shape = _to_concrete_shape(y_shape, var_table)
x = np.random.randn(*concrete_x_shape).astype(in_dtype)
y = np.random.randn(*concrete_y_shape).astype(in_dtype)
if transpose_y:
y = np.swapaxes(y, -2, -1)
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
if with_bias:
bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
args = (x, y, bias)
else:
bias = None
args = (x, y)
mod = get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
)
out = get_result_with_relax_cublas_offload(mod, args)
ref = build_and_run(mod, args, "llvm", legalize=True)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, out_dtype",
[
((10, 32), (64, 32), True, "float32"),
((32, 16), (32, 16), True, "float16"),
((2, 10, 32), (2, 64, 32), True, "float32"),
],
)
def test_matmul_fp8_offload(
x_shape,
y_shape,
transpose_y,
out_dtype,
):
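    """Offloaded e4m3 fp8 matmul must match a NumPy matmul reference on the same inputs."""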
in_dtype = "e4m3_float8"
mod = get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
bias_shape=None,
transposed_y=transpose_y,
activation=None,
)
numpytype = "float8_e4m3fn"
x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
z = np.swapaxes(y, -2, -1) if transpose_y else y
args = (x, y)
out = get_result_with_relax_cublas_offload(mod, args)
ref_out = np.matmul(x, z).astype(out_dtype)
tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)
@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
def test_matmul_fp8_dequantize_offload():
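    """Offloaded fp8 matmul followed by dequantize must match the legalized CPU reference."""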
x_shape = (10, 32)
y_shape = (64, 32)
in_dtype = "e4m3_float8"
mod = get_relax_matmul_dequantize_module(
x_shape,
y_shape,
in_dtype,
"float16",
transposed_y=True,
scale_const=0.34786,
zero_point_const=0.0,
)
numpytype = "float8_e4m3fn"
x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
args = (x, y)
out = get_result_with_relax_cublas_offload(mod, args, bind_constants=True)
ref = build_and_run(mod, args, "llvm", legalize=True)
tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize(
"M, N, K, out_dtype, transposed_y, partition_done",
[
(15, 64, 32, "float32", True, True),
(15, 64, 32, "e4m3_float8", True, True),
(15, 64, 32, "e5m2_float8", True, False),
(16, 32, 60, "float32", True, False),
(16, 30, 64, "float32", True, False),
(16, 8, 16, "float16", True, True),
(16, 16, 16, "float16", False, False),
],
)
def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y, partition_done):
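    """fp8 matmuls are partitioned only for supported output dtypes, alignments, and layouts."""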
mod = get_relax_matmul_module(
(M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=transposed_y
)
mod = partition_for_cublas(mod)
func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
assert func_name in mod["main"].script()
@pytest.mark.parametrize(
"M, N, K, scale, zp, num_bindings",
[
(16, 64, 32, 2.0, 0.0, 1),
(16, 64, 32, 2.0, 1.0, 2),
(16, 64, 32, [2.0] * 64, [2.0] * 64, 2),
],
)
def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings):
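    """Dequantize is fused into the cuBLAS partition only for a scalar scale and zero zero_point."""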
mod = get_relax_matmul_dequantize_module(
(M, K),
(N, K),
"e4m3_float8",
"float16",
transposed_y=True,
scale_const=scale,
zero_point_const=zp,
)
mod = partition_for_cublas(mod)
    # The binding count tells whether R.dequantize was fused into the cuBLAS partition
    # or left behind in the main function.
assert len(mod["main"].body.blocks[0].bindings) == num_bindings
def test_cublas_partition_matmul_without_bias():
# cuBLAS does not handle 2D bias (residual input)
mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
mod = partition_for_cublas(mod)
    # The matmul is offloaded, but R.add stays in the main function.
assert len(mod["main"].body.blocks[0].bindings) == 2
@pytest.mark.parametrize(
"M, N, K, was_partitioned", [(16, 8, 32, True), (16, 8, 33, False), (16, 9, 32, False)]
)
def test_cublas_partition_igemm(M, N, K, was_partitioned):
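    """int8 matmul is partitioned only when the shapes meet cuBLAS alignment requirements."""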
mod = get_relax_matmul_module((M, K), (K, N), "int8", "int32")
mod = partition_for_cublas(mod)
func_name = "fused_relax_matmul_cublas" if was_partitioned else "R.matmul"
assert func_name in mod["main"].script()
def test_cublas_partition_igemm_with_bias():
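    """The int8 matmul is offloaded while the bias add stays in the main function."""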
mod = get_relax_matmul_module((16, 32), (32, 8), "int8", "int32", bias_shape=(8,))
mod = partition_for_cublas(mod)
func = mod["main"].script()
assert "fused_relax_matmul_cublas" in func and "R.add" in func
def test_cublas_matmul_cuda_graph():
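    """A chain of offloaded matmuls must produce correct results when run through a CUDA graph."""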
@tvm.script.ir.ir_module
class Mod:
@R.function
def main(
x: R.Tensor((16, 16), "float16"),
w0: R.Tensor((16, 16), "float16"),
w1: R.Tensor((16, 16), "float16"),
w2: R.Tensor((16, 16), "float16"),
):
R.func_attr({"num_input": 1})
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(lv0, w1)
lv2 = R.matmul(lv1, w2)
R.output(lv2)
return lv2
mod = Mod
shape = [16, 16]
data = np.random.rand(*shape).astype(np.float16)
w0 = np.random.rand(*shape).astype(np.float16)
w1 = np.random.rand(*shape).astype(np.float16)
w2 = np.random.rand(*shape).astype(np.float16)
inputs = (data, w0, w1, w2)
out = get_result_with_relax_cublas_offload(Mod, inputs, cuda_graph=True)
with tvm.target.Target("cuda"):
mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
ref = build_and_run(mod, inputs, "llvm", legalize=True)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
tvm.testing.main()