| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import numpy as np |
| import pytest |
| |
| import tvm |
| import tvm.testing |
| import tvm.topi.testing |
| from tvm import relax |
| from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul |
| from tvm.contrib.pickle_memoize import memoize |
| from tvm.relax.backend.contrib.cutlass import partition_for_cutlass |
| from tvm.relax.testing import get_relax_matmul_module |
| from tvm.script import ir as I |
| from tvm.script import relax as R |
| from tvm.script import tir as T |
| from tvm.script.ir_builder import IRBuilder |
| from tvm.script.ir_builder import relax as relax_builder |
| |
| |
@pytest.fixture(autouse=True)
def reset_seed():
    """Reseed NumPy's global RNG with 0 before every test in this module."""
    np.random.seed(seed=0)
| |
| |
# TVMScript module: NHWC/OHWI conv2d with padding (1, 1), followed by a bias add and ReLU.
# NOTE(review): not referenced in the visible part of this file; presumably used further down.
@tvm.script.ir_module
class Conv2dBiasReLU:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), "float16"),
        weight: R.Tensor((32, 3, 3, 16), "float16"),
        bias: R.Tensor((1, 1, 1, 32), "float16"),
    ):
        with R.dataflow():
            # relu(conv2d(data, weight) + bias); the (1, 1, 1, 32) bias broadcasts over
            # the last (channel) axis of the NHWC conv output.
            conv1 = R.nn.relu(
                R.nn.conv2d(data, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI")
                + bias,
            )
            R.output(conv1)

        return conv1
| |
| |
# TVMScript module: two back-to-back NHWC/OHWI conv2d ops with identical shapes
# (padding (1, 1)). Used by test_kernel_sharing, presumably so that both convs can
# share one generated CUTLASS kernel.
@tvm.script.ir_module
class Conv2dx2:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 8), "float16"),
        weight1: R.Tensor((8, 3, 3, 8), "float16"),
        weight2: R.Tensor((8, 3, 3, 8), "float16"),
    ):
        with R.dataflow():
            # First conv keeps the (16, 32, 32, 8) shape, so it can feed the second one.
            conv1 = relax.op.nn.conv2d(
                data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
            )
            conv2 = relax.op.nn.conv2d(
                conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
            )
            R.output(conv2)

        return conv2
| |
| |
# Module-level `pytestmark` (pytest convention): the marks from tvm.testing.requires_cutlass
# are applied to every test in this file.
pytestmark = tvm.testing.requires_cutlass.marks()
| |
| |
def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
    """Compile ``mod`` with Relax for ``target`` and execute its ``main`` function.

    Parameters
    ----------
    mod : tvm.IRModule
        Module to compile.
    inputs_np : sequence of numpy.ndarray
        Arguments for ``main``; each is copied to device 0 of ``target``.
    target : str
        Target name passed to both ``relax.build`` and ``tvm.device``.
    legalize : bool
        Value of the "relax.transform.apply_legalize_ops" pass-context option.
    cuda_graph : bool
        Value of the "relax.backend.use_cuda_graph" option. When set, ``main`` is
        invoked once before the returned call so that the second call launches the
        cached graph.

    Returns
    -------
    numpy.ndarray
        Output of the last ``main`` invocation.
    """
    pass_config = {
        "relax.backend.use_cuda_graph": cuda_graph,
        "relax.transform.apply_legalize_ops": legalize,
    }
    with tvm.transform.PassContext(config=pass_config):
        executable = relax.build(mod, target)

    device = tvm.device(target, 0)
    machine = relax.VirtualMachine(executable, device)
    entry = machine["main"]
    device_args = [tvm.nd.array(array, device) for array in inputs_np]

    if cuda_graph:
        # Warm-up call: the result is discarded, the next call must reuse the cached graph.
        entry(*device_args)

    result = entry(*device_args)
    return result.numpy()
| |
| |
def build_cutlass(mod, assert_all_bindings_fused=True, num_final_bindings=1):
    """Partition ``mod`` for CUTLASS and run the CUTLASS codegen on it.

    Parameters
    ----------
    mod : tvm.IRModule
        Module to offload.
    assert_all_bindings_fused : bool
        When true, assert that the first binding block of ``main`` has exactly
        ``num_final_bindings`` bindings after partitioning.
    num_final_bindings : int
        Expected binding count used by that assertion.

    Returns
    -------
    tvm.IRModule
        The module after ``RunCodegen`` with sm=80 and find_first_valid=True.
    """
    partitioned = partition_for_cutlass(mod)

    if assert_all_bindings_fused:
        main_bindings = partitioned["main"].body.blocks[0].bindings
        assert len(main_bindings) == num_final_bindings

    cutlass_options = {"sm": 80, "find_first_valid": True}
    run_codegen = relax.transform.RunCodegen({"cutlass": cutlass_options})
    return run_codegen(partitioned)
| |
| |
def get_result_with_relax_cutlass_offload(
    mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
):
    """Offload ``mod`` to CUTLASS, run it on CUDA with ``args`` and return the NumPy output."""
    offloaded = build_cutlass(
        mod,
        assert_all_bindings_fused=assert_all_bindings_fused,
        num_final_bindings=num_final_bindings,
    )
    return build_and_run(offloaded, args, "cuda")
| |
| |
def test_kernel_sharing():
    """Two same-shaped conv2d ops offloaded to CUTLASS must exactly match the LLVM result."""

    def _random_int_tensor(shape):
        # Small integer values drawn from [-1, 1), stored as float16.
        return np.random.randint(-1, 1, size=shape).astype("float16")

    data_np = _random_int_tensor((16, 32, 32, 8))
    weight1_np = _random_int_tensor((8, 3, 3, 8))
    weight2_np = _random_int_tensor((8, 3, 3, 8))
    inputs = [data_np, weight1_np, weight2_np]

    out = get_result_with_relax_cutlass_offload(
        Conv2dx2, *inputs, assert_all_bindings_fused=False
    )
    ref = build_and_run(Conv2dx2, inputs, "llvm")

    np.testing.assert_equal(out, ref)
| |
| |
def get_relax_conv2d_module(
    data_shape,
    weight_shape,
    dtype,
    with_bias=False,
    activation=None,
    residual_bin_op=None,
    residual_activation=None,
):
    """Build an IRModule whose ``main`` is an NHWC/OHWI conv2d with optional epilogues.

    ``main`` takes ``data`` and ``weight``, plus ``bias`` of shape
    ``(1, 1, 1, weight_shape[0])`` when ``with_bias`` is set. It computes, in order:
    conv2d (padding (1, 1), ``out_dtype=dtype``), ``+ bias``, ``activation(...)``,
    ``residual_bin_op(..., data)`` and ``residual_activation(...)``. Every step after
    the conv2d is emitted only when its argument is set. When a residual op is used,
    the conv2d output must be compatible with ``data`` under that op.

    Parameters
    ----------
    data_shape, weight_shape : tuple
        Shapes of ``data`` (NHWC) and ``weight`` (OHWI). Symbolic dims are accepted.
    dtype : str
        Element dtype of all arguments and of the conv2d output.
    with_bias : bool
        Whether to add a ``bias`` argument and a bias add.
    activation : callable, optional
        Relax op applied after the (optional) bias add.
    residual_bin_op : callable, optional
        Relax binary op called as ``residual_bin_op(output, data)``.
    residual_activation : callable, optional
        Relax op applied after the residual op.

    Returns
    -------
    tvm.IRModule
        Module containing the single function ``"main"``.
    """
    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            # Argument order defines main's signature: data, weight[, bias].
            data = R.arg("data", R.Tensor(data_shape, dtype))
            weight = R.arg("weight", R.Tensor(weight_shape, dtype))
            if with_bias:
                bias = R.arg("bias", R.Tensor((1, 1, 1, weight_shape[0]), dtype))

            with R.dataflow() as frame:
                output = R.emit(
                    R.nn.conv2d(
                        data,
                        weight,
                        out_dtype=dtype,
                        padding=(1, 1),
                        data_layout="NHWC",
                        kernel_layout="OHWI",
                    )
                )
                if with_bias:
                    output = R.emit(output + bias)
                if activation is not None:
                    output = R.emit(activation(output))
                if residual_bin_op is not None:
                    # The residual input is the original `data` argument.
                    output = R.emit(residual_bin_op(output, data))
                    if residual_activation is not None:
                        output = R.emit(residual_activation(output))
                R.output(output)

            R.func_ret_value(frame.output_vars[0])

    func = builder.get()
    return tvm.IRModule({"main": func})
| |
| |
| def _to_concrete_shape(symbolic_shape, var_table=None): |
| if var_table is None: |
| var_table = {} |
| |
| result = [] |
| for dim in symbolic_shape: |
| if isinstance(dim, tuple): |
| result.append(_to_concrete_shape(dim, var_table)) |
| continue |
| |
| if not isinstance(dim, tvm.tir.expr.Var): |
| result.append(dim) |
| continue |
| |
| if dim not in var_table: |
| var_table[dim] = np.random.randint(10, 50) |
| result.append(var_table[dim]) |
| |
| return tuple(result) |
| |
| |
# Symbolic int64 dims reused by the parametrized matmul shape cases.
_vars = dict(
    a=tvm.tir.expr.Var("a", "int64"),
    b=tvm.tir.expr.Var("b", "int64"),
)


# Epilogue name -> (with_bias, optional activation op).
_epilogue_table = dict(
    none=(False, None),
    bias=(True, None),
    relu=(True, R.nn.relu),
    gelu=(True, R.nn.gelu),
    silu=(True, R.nn.silu),
)


# Residual-block name -> (binary op that combines the result with the residual
# input, optional activation applied afterwards).
_residual_block_table = dict(
    none=(None, None),
    add_relu=(R.add, R.nn.relu),
    mul_relu=(R.multiply, R.nn.relu),
    add=(R.add, None),
    mul=(R.multiply, None),
)
| |
| |
@pytest.mark.parametrize(
    "data_shape, weight_shape, dtype, epilogue, residual_block",
    [
        # Regular
        ((16, 32, 32, 16), (32, 3, 3, 16), "float16", "none", "none"),
        ((40, 128, 50, 16), (16, 2, 2, 16), "float16", "bias", "none"),
        ((3, 64, 64, 128), (32, 1, 1, 128), "float16", "relu", "none"),
        ((12, 32, 32, 16), (45, 5, 5, 16), "float16", "silu", "none"),
        # residual block
        ((3, 64, 64, 16), (16, 3, 3, 16), "float16", "relu", "add"),
        ((16, 32, 32, 16), (16, 3, 3, 16), "float16", "relu", "mul_relu"),
        ((40, 128, 50, 16), (16, 3, 3, 16), "float16", "bias", "add_relu"),
        ((128, 32, 32, 16), (16, 3, 3, 16), "float16", "silu", "mul"),
    ],
)
def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_block):
    """Conv2d with epilogue/residual fusion offloaded to CUTLASS must match LLVM."""

    def _random_int_tensor(shape):
        return np.random.randint(-1, 1, size=shape).astype(dtype)

    # The bias is always drawn, even when unused, to keep the RNG sequence fixed.
    data = _random_int_tensor(data_shape)
    weight = _random_int_tensor(weight_shape)
    bias = _random_int_tensor((1, 1, 1, weight_shape[0]))

    with_bias, activation = _epilogue_table[epilogue]
    residual_bin_op, residual_activation = _residual_block_table[residual_block]

    args = (data, weight, bias) if with_bias else (data, weight)

    mod = get_relax_conv2d_module(
        data_shape,
        weight_shape,
        dtype,
        with_bias=with_bias,
        activation=activation,
        residual_bin_op=residual_bin_op,
        residual_activation=residual_activation,
    )
    out = get_result_with_relax_cutlass_offload(mod, *args)
    ref = build_and_run(mod, args, "llvm")

    tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
| |
| |
@pytest.mark.parametrize(
    "data_shape, weight_shape, dtype",
    [
        # batch dynamism
        ((T.Var("n", "int64"), 32, 32, 16), (32, 3, 3, 16), "float16"),
        # channel dynamism
        # NOTE(review): the two T.Var("c", ...) calls create two distinct Var objects
        # that only share a name; confirm this is intended rather than one shared dim.
        ((16, 32, 32, T.Var("c", "int64")), (32, 3, 3, T.Var("c", "int64")), "float16"),
    ],
)
def test_conv2d_dynamic(data_shape, weight_shape, dtype):
    """A dynamic-shape conv2d must pass through CUTLASS offloading without being offloaded."""
    dynamic_mod = get_relax_conv2d_module(data_shape, weight_shape, dtype)
    offloaded = build_cutlass(dynamic_mod)
    # No cutlass call may be introduced until dynamic shapes are supported.
    assert "call_dps" not in repr(offloaded)
| |
| |
def test_cutlass_partition_conv2d_residual_blocked():
    """``out + out`` must not be fused as a conv2d residual block."""

    @tvm.script.ir_module
    class Conv2dReLU:
        """
        This conv2d should not be fused as conv2d residual block, because both lhs and rhs of
        the last R.add depend on the result of conv2d.
        """

        @R.function
        def main(
            data: R.Tensor((32, 3, 3, 16), "float32"),
            weight: R.Tensor((16, 3, 3, 16), "float32"),
            bias: R.Tensor((1, 1, 1, 16), "float32"),
        ):
            with R.dataflow():
                conv1 = R.nn.conv2d(
                    data,
                    weight,
                    padding=(1, 1),
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                )
                out = R.nn.relu(conv1 + bias)
                # residual depends on conv result, which cannot be handled in cutlass
                result = out + out
                R.output(result)

            return result

    partitioned = partition_for_cutlass(Conv2dReLU, annotate_codegen=False)
    for gvar in partitioned.functions:
        func = partitioned[gvar]
        if "Composite" not in func.attrs:
            continue
        # Only the plain conv2d + bias + relu pattern may be matched.
        assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu"
| |
| |
@pytest.mark.parametrize(
    "x_shape, y_shape, transpose_y, epilogue, residual_block",
    [
        # Regular
        ((32, 6), (6, 16), False, "none", "none"),
        ((_vars["a"], 6), (6, 16), False, "bias", "none"),
        # Transposed
        ((4, 16), (16, 128), True, "relu", "none"),
        ((35, 8), (8, 8), True, "gelu", "none"),
        # 3D x 3D
        ((6, 32, 8), (6, 8, 10), False, "bias", "none"),
        ((6, 32, 8), (6, 8, 10), True, "none", "none"),
        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
        # 3D x 2D
        ((6, 32, 8), (8, 10), False, "none", "none"),
        ((_vars["a"], 32, 8), (8, 10), False, "bias", "none"),
        ((10, 16, 8), (8, 10), True, "relu", "none"),
        # 2D x 3D
        ((32, 8), (10, 8, 10), False, "relu", "none"),
        ((32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
        # ND x 2D
        ((3, 6, 32, 8), (8, 10), False, "bias", "none"),
        ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none", "none"),
        # 2D x ND
        ((32, 8), (5, 3, 8, 10), False, "gelu", "none"),
        # ND x ND
        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu", "none"),
        ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu", "none"),
        ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none", "none"),
        # Residual
        ((32, 8), (8, 8), False, "bias", "add"),
        ((4, 16), (16, 16), True, "relu", "add_relu"),
        ((8, 32, 8), (8, 8, 8), False, "bias", "add"),
        ((5, 3, 32, 8), (8, 8), True, "relu", "add"),
        # Residual fusion without bias - this is supported via the matmul + bias pattern
        # where bias == residual input
        ((4, 16), (16, 16), False, "none", "add"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        "float16",
    ],
)
def test_matmul_offload(
    x_shape,
    y_shape,
    transpose_y,
    epilogue,
    residual_block,
    dtype,
):
    """Matmul with epilogue/residual fusion offloaded to CUTLASS must match LLVM."""
    with_bias, activation = _epilogue_table[epilogue]
    residual_bin_op, residual_activation = _residual_block_table[residual_block]

    # Both concrete shapes are resolved before any tensor is drawn, with one shared
    # table so symbolic dims common to x and y get the same value.
    shared_vars = {}
    concrete_x_shape = _to_concrete_shape(x_shape, shared_vars)
    concrete_y_shape = _to_concrete_shape(y_shape, shared_vars)
    x = np.random.randn(*concrete_x_shape).astype(dtype)
    y = np.random.randn(*concrete_y_shape).astype(dtype)

    if transpose_y:
        # The module receives y already transposed, so its declared shape swaps the
        # last two dims as well.
        y = np.swapaxes(y, -2, -1)
        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])

    bias = np.random.randn(concrete_y_shape[-1]).astype(dtype) if with_bias else None
    args = (x, y, bias) if with_bias else (x, y)

    mod = get_relax_matmul_module(
        x_shape,
        y_shape,
        dtype,
        bias_shape=bias.shape if with_bias else None,
        transposed_y=transpose_y,
        activation=activation,
        residual_bin_op=residual_bin_op,
        residual_activation=residual_activation,
    )
    out = get_result_with_relax_cutlass_offload(mod, *args)
    ref = build_and_run(mod, args, "llvm")

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_matmul_with_3d_bias_offload():
    """Batched matmul followed by a (1, M, N) bias add, offloaded to CUTLASS."""
    dtype = "float16"
    x_shape, y_shape = (1, 4, 8), (1, 8, 16)
    bias_shape = (1, x_shape[-2], y_shape[-1])

    x = np.random.randn(*x_shape).astype(dtype)
    y = np.random.randn(*y_shape).astype(dtype)
    bias = np.random.randn(*bias_shape).astype(dtype)
    args = (x, y, bias)

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(
            x: R.Tensor((1, 4, 8), "float16"),
            y: R.Tensor((1, 8, 16), "float16"),
            bias: R.Tensor((1, 4, 16), "float16"),
        ):
            with R.dataflow():
                lv1 = R.matmul(x, y)
                gv1 = lv1 + bias
                R.output(gv1)

            return gv1

    out = get_result_with_relax_cutlass_offload(Mod, *args)
    ref = build_and_run(Mod, args, "llvm")

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
@pytest.mark.parametrize(
    "x_shape, y_shape, expected",
    [
        # Regular matmul
        ((3, 4), (4, 5), True),
        # Batch matmul without stretching
        ((3, 16, 15), (3, 15, 2), True),
        ((_vars["a"], 16, 15), (_vars["a"], 15, 2), True),
        # Broadcast 2D to 3D
        ((3, 16, 15), (15, 2), True),
        ((_vars["a"], 16, 15), (15, 2), True),
        ((16, 15), (3, 15, 2), True),
        # Broadcast one-length dimension
        ((1, 16, 15), (3, 15, 2), True),
        ((3, 16, 15), (1, 15, 2), True),
        ((1, 1, 16, 15), (3, 2, 4, 15, 2), True),
        ((1, 1, 16, 15), (3, _vars["a"], 4, 15, 2), True),
        # ND x ND
        ((3, 2, 4, 16, 15), (3, 2, 4, 15, 2), True),
        ((_vars["a"], 2, 4, 16, 15), (_vars["a"], 2, 4, 15, 2), True),
        (
            (_vars["a"], _vars["b"], 4, 16, 15),
            (_vars["a"], _vars["b"], 4, 15, 2),
            True,
        ),
        # ND x ND with one-length dimension
        ((1, 2, 4, 16, 15), (1, 2, 4, 15, 2), True),
        ((3, 2, 1, 16, 15), (3, 2, 1, 15, 2), True),
        # Extra one-length dimension doesn't block broadcasting
        ((3, 2, 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), True),
        # Not broadcasting all dims. Cannot be computed by stride-based batch gemm
        ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False),
        ((3, 2, 4, 16, 15), (2, 4, 15, 2), False),
        # Different shape
        ((3, 4, 16, 15), (3, 2, 15, 2), False),
        ((3, _vars["a"], 16, 15), (3, _vars["b"], 15, 2), False),
        # Cannot prove that broadcast dimensions are equal
        ((_vars["a"], 16, 15), (3, 15, 2), False),
        ((3, _vars["a"], 1, 16, 15), (1, 1, 3, 2, 1, 15, 2), False),
        # Reduction axis must be constant
        ((3, _vars["a"]), (_vars["a"], 5), False),
    ],
)
def test_is_shape_valid_for_cutlass_matmul(x_shape, y_shape, expected):
    """The shape checker must accept exactly the shapes stride-based batch GEMM handles."""
    is_valid = is_shape_valid_for_cutlass_matmul(x_shape, y_shape)
    assert is_valid == expected
| |
| |
@pytest.mark.parametrize(
    "x_shape, y_shape, transpose_y, dtype",
    [
        # Not broadcasting all dims. Cannot be computed by stride-based batch gemm
        ((3, 1, 1, 16, 15), (3, 2, 4, 15, 2), False, "float16"),
        ((3, 2, _vars["a"], 16, 15), (3, 2, 4, 15, 2), False, "float16"),
        ((1, 2, 1, 16, 15), (2, 1, 4, 15, 2), False, "float16"),
        ((3, 2, 4, 16, 15), (2, 4, 15, 2), True, "float16"),
        ((3, 16, 15), (2, 1, 3, 15, 2), True, "float16"),
        ((3, 16, 15), (_vars["a"], 1, 3, 15, 2), True, "float16"),
        ((_vars["a"], 1, 3, 16, 15), (_vars["b"], 1, 3, 15, 2), True, "float16"),
        ((_vars["a"], _vars["b"], 3, 16, 15), (_vars["a"], 1, 3, 15, 2), True, "float16"),
    ],
)
def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
    """Matmuls with unsupported broadcasting must not be partitioned for CUTLASS."""
    rhs_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) if transpose_y else y_shape

    matmul_mod = get_relax_matmul_module(x_shape, rhs_shape, dtype, transposed_y=transpose_y)
    partitioned = partition_for_cutlass(matmul_mod)

    # Only `main` should remain: no offloaded function was created.
    assert len(partitioned.functions) == 1
| |
| |
def test_cutlass_partition_matmul_tuple_return_blocked():
    """A permute_dims with several consumers must not be fused into a transposed matmul."""

    @tvm.script.ir_module
    class TransposedMatmul:
        @R.function
        def main(
            x: R.Tensor((4, 4), "float32"),
            y: R.Tensor((4, 4), "float32"),
        ):
            with R.dataflow():
                lv1 = R.permute_dims(y)
                # Because lv1 is used by both lv2 and out, it should stay out of
                # the fused function. Otherwise the fused function will return
                # tuple output, which isn't possible in cutlass, e.g.
                # @R.function
                # def fused_relax_permute_dims_relax_matmul(...):
                #     R.func_attr({"Composite": "cutlass.matmul_transposed", "Primitive": 1})
                #     with R.dataflow():
                #         gv: R.Tensor((4, 4), dtype="float32") = R.permute_dims(y, axes=None)
                #         gv1: R.Tensor((4, 4), dtype="float32") = R.matmul(x, gv, out_dtype="void")
                #         R.output(gv, gv1)
                #     return (gv, gv1)  # Cannot get `gv` if dispatch to cutlass kernel.
                lv2 = R.matmul(x, lv1)
                out = R.matmul(lv1, lv2)
                R.output(out)

            return out

    partitioned = partition_for_cutlass(TransposedMatmul, annotate_codegen=False)
    for gvar in partitioned.functions:
        func = partitioned[gvar]
        if "Composite" not in func.attrs:
            continue
        # Any composite found must be the plain matmul, not the transposed variant.
        assert func.attrs["Composite"] == "cutlass.matmul"
| |
| |
def test_cutlass_partition_matmul_cyclic_dependency_blocked():
    """matmul + add must not fuse as matmul_bias when the addend depends on the matmul."""

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((128, 128), "float16"), w: R.Tensor((128, 128), "float16")):
            with R.dataflow():
                # Because lv1 depends on lv, this block should be fused as matmul instead of matmul_bias.
                lv = R.matmul(x, w)
                lv1 = R.power(lv, R.const(2.0, "float16"))
                lv2 = R.add(lv, lv1)
                R.output(lv2)
            return lv2

    partitioned = partition_for_cutlass(Module, annotate_codegen=False)
    for gvar in partitioned.functions:
        func = partitioned[gvar]
        if "Composite" not in func.attrs:
            continue
        assert func.attrs["Composite"] == "cutlass.matmul"
| |
| |
@pytest.fixture(params=["float16", "float32"])
def attention_dtype(request):
    """Element dtype for the attention offload tests."""
    dtype = request.param
    return dtype


@pytest.fixture(
    params=[
        # B, S, N, H
        (32, (_vars["a"], 8), 16, (8, 8)),
        (32, (8, 8), 16, (8, 8)),
        (4, (16, 8), 32, (8, 8)),  # s != s_kv
        (4, (16, 8), 32, (8, 16)),  # h != h_v
        (32, (8, 8), 16, (4, 4)),  # h is not aligned
        (2, (8, 8), 8, (256, 256)),  # needs output accumulator buffer
    ]
)
def attention_size(request):
    """Attention sizes as ``(batch, (seq, seq_kv), num_heads, (head, head_v))``."""
    size = request.param
    return size
| |
| |
def get_relax_attention_module(
    q_shape,
    k_shape,
    v_shape,
    *,
    dtype,
    bias_shape=None,
    qk_scale=None,
    causal_mask=None,
    window_size=None,
):
    """Build an IRModule whose ``main`` applies a single ``R.nn.attention``.

    ``main`` takes ``q``, ``k`` and ``v``, plus ``bias`` when ``bias_shape`` is
    neither None nor the string ``"none"``, and returns
    ``R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size)``.

    Parameters
    ----------
    q_shape, k_shape, v_shape : tuple
        Argument shapes; symbolic dims are passed straight to ``R.Tensor``.
    dtype : str
        Element dtype of every argument.
    bias_shape : tuple or str, optional
        Shape of the optional bias argument; None or ``"none"`` means no bias.
    qk_scale : float, optional
        Wrapped as a float32 ``FloatImm`` before being passed to the op.
    causal_mask : str, optional
        Forwarded unchanged to ``R.nn.attention``.
    window_size : int, optional
        Wrapped as an int32 ``IntImm`` before being passed to the op.

    Returns
    -------
    tvm.IRModule
        Module containing the single function ``"main"``.
    """
    from tvm.script.ir_builder import IRBuilder
    from tvm.script.ir_builder import relax as relax_builder
    # This local `T` shadows the module-level `tvm.script.tir` import inside this function.
    from tvm.script.ir_builder import tir as T

    if qk_scale is not None:
        qk_scale = T.FloatImm("float32", qk_scale)

    if window_size is not None:
        window_size = T.IntImm("int32", window_size)

    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            # Argument order defines main's signature: q, k, v[, bias].
            q = R.arg("q", R.Tensor(q_shape, dtype))
            k = R.arg("k", R.Tensor(k_shape, dtype))
            v = R.arg("v", R.Tensor(v_shape, dtype))
            bias = None
            if bias_shape is not None and bias_shape != "none":
                bias = R.arg("bias", R.Tensor(bias_shape, dtype))

            with R.dataflow() as frame:
                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size))
                R.output(result)

            R.func_ret_value(frame.output_vars[0])

    func = builder.get()
    return tvm.IRModule({"main": func})
| |
| |
def get_numpy_attention_ref(
    b,
    s,
    s_kv,
    n,
    h,
    h_v,
    bias_shape,
    qk_scale,
    causal,
    dtype,
    window_size=None,
    num_kv_head=None,
):
    """Draw random attention inputs and compute a NumPy reference output.

    Inputs come from NumPy's global RNG in the order q, k, v, then bias.

    Parameters
    ----------
    b, s, s_kv, n, h, h_v : int
        Batch size, query length, key/value length, number of query heads,
        query/key head size and value head size.
    bias_shape : tuple or str
        Shape of an additive bias broadcast (NumPy rules) against the
        (b, n, s, s_kv) score tensor, or ``"none"`` for no bias.
    qk_scale : float or str
        Factor applied to ``q @ k^T``; ``"none"`` means ``1 / sqrt(h)``.
    causal : str
        ``"none"``, ``"TopLeft"`` or ``"BottomRight"``.
    dtype : str
        NumPy dtype of the generated inputs.
    window_size : int, optional
        Sliding-window width. Only applied when ``causal`` is not ``"none"``.
    num_kv_head : int, optional
        Number of key/value heads for grouped-query attention. Defaults to ``n``.
        Each K/V head is shared by ``n // num_kv_head`` consecutive query heads.

    Returns
    -------
    tuple
        ``(q, k, v, bias, ref)``. ``q`` is (b, s, n, h), ``k`` is
        (b, s_kv, num_kv_head, h) and ``v`` is (b, s_kv, num_kv_head, h_v).
        ``bias`` is None when ``bias_shape == "none"``. ``ref`` is (b, s, n, h_v).

    Raises
    ------
    ValueError
        If ``num_kv_head`` is not a positive divisor of ``n``.
    NotImplementedError
        If ``causal`` is not ``"none"``, ``"TopLeft"`` or ``"BottomRight"``.
    """
    if num_kv_head is None:
        num_kv_head = n
    # Check up front: otherwise a bad value only fails later, as an opaque matmul
    # broadcast error or a ZeroDivisionError.
    if num_kv_head <= 0 or n % num_kv_head != 0:
        raise ValueError(f"num_kv_head ({num_kv_head}) must be a positive divisor of n ({n})")

    q = np.random.randn(b, s, n, h).astype(dtype)
    k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
    v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)

    # Grouped-query attention: repeat each K/V head for the query heads that share it.
    factor = n // num_kv_head
    if factor == 1:
        k, v = k_orig, v_orig
    else:
        k = np.repeat(k_orig, factor, axis=2)
        v = np.repeat(v_orig, factor, axis=2)

    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
    if not qk_scale == "none":
        score = qt @ kt * qk_scale  # b, n, s, s_kv
    else:
        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
    if not bias_shape == "none":
        bias = np.random.randn(*bias_shape).astype(dtype)
        score = score + bias  # b, n, s, s_kv
    else:
        bias = None
    if causal == "none":
        attn = tvm.topi.testing.softmax_python(score, -1)
    else:
        if causal == "TopLeft":
            offset = 0
        elif causal == "BottomRight":
            # NOTE(review): abs() gives +|s - s_kv| even when s > s_kv, where a
            # bottom-right diagonal is usually s_kv - s. Kept unchanged; confirm it
            # matches the kernel's convention.
            offset = abs(s - s_kv)
        else:
            raise NotImplementedError()
        # Masked entries are zeroed (not set to -inf) before the max. They are zeroed
        # again after exp so they contribute nothing to the softmax sum.
        score_masked = np.tril(score, k=offset)

        if window_size:
            score_masked = np.triu(score_masked, -window_size + 1)

        score_masked_exp = np.tril(
            np.exp(score_masked - np.max(score_masked, axis=-1, keepdims=True)), k=offset
        )

        if window_size:
            score_masked_exp = np.triu(score_masked_exp, -window_size + 1)

        score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
        attn = np.divide(score_masked_exp, score_masked_sum)

    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
    ref = attn @ vt  # b, n, s, h_v
    return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
| |
| |
def test_attention_offload(attention_size, attention_dtype):
    """Plain attention (no bias, default scale, no mask) offloaded to CUTLASS."""
    b, (s, s_kv), n, (h, h_v) = attention_size
    concrete_s, concrete_s_kv = _to_concrete_shape((s, s_kv))
    q, k, v, _, ref = get_numpy_attention_ref(
        b,
        concrete_s,
        concrete_s_kv,
        n,
        h,
        h_v,
        bias_shape="none",
        qk_scale="none",
        causal="none",
        dtype=attention_dtype,
    )

    # The module keeps the symbolic sequence lengths; only the inputs are concrete.
    mod = get_relax_attention_module(
        (b, s, n, h), (b, s_kv, n, h), (b, s_kv, n, h_v), dtype=attention_dtype
    )
    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
@pytest.fixture(
    params=[
        # B, S, N, H, bias_shape
        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8)),
        (4, (16, 8), 32, (8, 16), (4, 1, 16, 8)),
        (4, (16, 8), 32, (8, 16), (4, 32, 1, 8)),
        (4, (16, 8), 32, (8, 16), (4, 1, 1, 8)),
        (4, (16, 8), 32, (8, 16), (1, 32, 16, 8)),
        (4, (16, 8), 32, (8, 16), (1, 1, 16, 8)),
        (4, (16, 8), 32, (8, 16), (1, 32, 1, 8)),
        (4, (16, 8), 32, (8, 16), (1, 1, 1, 8)),
    ]
)
def attention_bias_size(request):
    """Attention sizes plus a bias shape covering each broadcast pattern of the score."""
    size = request.param
    return size
| |
| |
def test_attention_bias_offload(attention_bias_size):
    """Attention with an additive (broadcast) bias offloaded to CUTLASS."""
    b, (s, s_kv), n, (h, h_v), bias_shape = attention_bias_size
    concrete_s, concrete_s_kv, concrete_bias_shape = _to_concrete_shape((s, s_kv, bias_shape))

    q, k, v, bias, ref = get_numpy_attention_ref(
        b, concrete_s, concrete_s_kv, n, h, h_v, concrete_bias_shape, "none", "none", "float32"
    )

    q_shape = (b, s, n, h)
    k_shape, v_shape = (b, s_kv, n, h), (b, s_kv, n, h_v)
    mod = get_relax_attention_module(
        q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
    )
    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
@pytest.fixture(
    params=[
        # B, S, N, H, bias_shape
        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8)),
        (4, (16, 8), 32, (8, 16), "none"),
    ]
)
def attention_scale_size(request):
    """Attention sizes for the custom-scale tests, with and without a bias."""
    size = request.param
    return size


@pytest.fixture(params=[0.01, 1e-8, -0.5, 1.23])
def attention_scale(request):
    """Explicit q @ k^T scale factors, including tiny and negative values."""
    scale = request.param
    return scale
| |
| |
def test_attention_scale_offload(attention_scale_size, attention_scale):
    """Attention with an explicit qk scale (optionally with bias) offloaded to CUTLASS."""
    b, (s, s_kv), n, (h, h_v), bias_shape = attention_scale_size
    q, k, v, bias, ref = get_numpy_attention_ref(
        b, s, s_kv, n, h, h_v, bias_shape, attention_scale, "none", "float32"
    )

    mod = get_relax_attention_module(
        (b, s, n, h),
        (b, s_kv, n, h),
        (b, s_kv, n, h_v),
        dtype="float32",
        bias_shape=bias_shape,
        qk_scale=attention_scale,
    )
    # The bias is an extra runtime argument only when one was generated.
    inputs = (q, k, v) if bias is None else (q, k, v, bias)
    out = get_result_with_relax_cutlass_offload(mod, *inputs, num_final_bindings=3)
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
@pytest.fixture(
    params=[
        # B, S, N, H, bias_shape
        (2, (16, 8), 4, (8, 16), "none"),
        (2, (8, 16), 4, (8, 16), "none"),
        (2, (16, 8), 4, (8, 16), (2, 4, 16, 8)),
    ]
)
def attention_causal_size(request):
    """Attention sizes for the causal-mask tests (both s > s_kv and s < s_kv)."""
    size = request.param
    return size


@pytest.fixture(params=["TopLeft", "BottomRight"])
def attention_causal(request):
    """Causal mask variant forwarded to both the reference and the Relax op."""
    mask_kind = request.param
    return mask_kind
| |
| |
def test_attention_causal_offload(attention_causal_size, attention_causal):
    """Causal-masked attention (optionally with bias) offloaded to CUTLASS."""
    b, (s, s_kv), n, (h, h_v), bias_shape = attention_causal_size
    q, k, v, bias, ref = get_numpy_attention_ref(
        b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float16"
    )

    mod = get_relax_attention_module(
        (b, s, n, h),
        (b, s_kv, n, h),
        (b, s_kv, n, h_v),
        dtype="float16",
        bias_shape=bias_shape,
        causal_mask=attention_causal,
    )

    inputs = (q, k, v) if bias is None else (q, k, v, bias)
    out = get_result_with_relax_cutlass_offload(mod, *inputs, num_final_bindings=3)
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
@memoize("topi.tests.test_codegen_cutlass.test_stacked_attention_offload")
def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype):
    """Draw a stacked QKV tensor (and optional bias) and compute reference attention.

    ``qkv`` has shape (b, s, 2 * n * h + n * h_v). Along its last axis it holds Q, K
    and V back to back. ``qk_scale == "none"`` means ``1 / sqrt(h)``, and
    ``bias_shape == "none"`` means no bias. The result is memoized by arguments.

    Returns
    -------
    tuple
        ``(qkv, bias, ref)`` with ``ref`` of shape (b, s, n, h_v) and ``bias``
        None when no bias is requested.
    """
    qk_width = n * h
    qkv = np.random.randn(b, s, 2 * qk_width + n * h_v).astype(dtype)

    q_flat, k_flat, v_flat = np.split(qkv, [qk_width, 2 * qk_width], axis=2)
    q = q_flat.reshape(b, s, n, h)
    k = k_flat.reshape(b, s, n, h)
    v = v_flat.reshape(b, s, n, h_v)

    # (b, n, s, h) @ (b, n, h, s) -> (b, n, s, s)
    scores = q.transpose(0, 2, 1, 3) @ k.transpose(0, 2, 3, 1)
    if qk_scale == "none":
        scores = scores / np.sqrt(q.shape[-1])
    else:
        scores = scores * qk_scale

    if bias_shape == "none":
        bias = None
    else:
        bias = np.random.randn(*bias_shape).astype(dtype)
        scores = scores + bias

    probs = tvm.topi.testing.softmax_python(scores, -1)
    out = probs @ v.transpose(0, 2, 1, 3)  # b, n, s, h_v
    return qkv, bias, out.transpose(0, 2, 1, 3)  # b, s, n, h_v
| |
| |
def get_relax_stacked_attention_module(
    qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False
):
    """Build an IRModule that slices Q/K/V out of one stacked tensor and applies attention.

    Parameters
    ----------
    qkv : numpy.ndarray
        Only its shape and dtype are used, to declare the ``qkv`` argument.
    b, s, n, h, h_v : int
        Batch, sequence length, heads, Q/K head size and V head size.
    op : str
        ``"split"`` (``R.split`` at [n*h, 2*n*h] on axis 2) or ``"strided_slice"``
        (three ``R.strided_slice`` calls on axis 2). Anything else raises
        NotImplementedError.
    bias : numpy.ndarray, optional
        When given, only its shape is used, to declare an extra ``bias`` argument.
    qk_scale : float, optional
        Wrapped as a float32 ``FloatImm`` for ``R.nn.attention``.
    single_shape : bool
        When true, Q, K and V are all reshaped to one ``R.shape([b, s, n, h])``,
        which only works when ``h == h_v``.

    Returns
    -------
    tvm.IRModule
        Module containing the single function ``"main"``.
    """
    dtype = str(qkv.dtype)

    from tvm.script.ir_builder import IRBuilder
    from tvm.script.ir_builder import relax as relax_builder
    # This local `T` shadows the module-level `tvm.script.tir` import inside this function.
    from tvm.script.ir_builder import tir as T

    if qk_scale is not None:
        qk_scale = T.FloatImm("float32", qk_scale)

    if single_shape:
        qk_shape = R.shape([b, s, n, h])
        v_shape = qk_shape
    else:
        qk_shape = [b, s, n, h]
        v_shape = [b, s, n, h_v]

    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            # `qkv` and `bias` are rebound from NumPy arrays to Relax arguments here.
            qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
            if bias is not None:
                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
            with R.dataflow() as frame:
                if op == "split":
                    qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
                    q = R.reshape(qkv_tuple[0], qk_shape)
                    k = R.reshape(qkv_tuple[1], qk_shape)
                    v = R.reshape(qkv_tuple[2], v_shape)
                elif op == "strided_slice":
                    q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), qk_shape)
                    k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), qk_shape)
                    v = R.reshape(
                        R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), v_shape
                    )
                else:
                    raise NotImplementedError()
                # q/k/v are un-emitted expressions nested inside this single emitted call.
                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
                R.output(result)

            R.func_ret_value(frame.output_vars[0])

    func = builder.get()
    return tvm.IRModule({"main": func})
| |
| |
@pytest.fixture(
    params=[
        # B, S, N, H, bias_shape, scale, single_shape
        (4, 8, 32, (64, 32), "none", "none", False),
        (4, 8, 32, (64, 32), (4, 32, 8, 8), 0.5, False),
        (4, 8, 32, (64, 64), "none", "none", True),
    ]
)
def stacked_attention_size(request):
    """Stacked-QKV attention configurations (the single-shape case requires h == h_v)."""
    config = request.param
    return config
| |
| |
def test_stacked_attention_split_offload(stacked_attention_size):
    """Attention over Q/K/V obtained with R.split from a stacked tensor, offloaded to CUTLASS."""
    b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size
    qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16")

    # Forward the scale only when one is set; otherwise the builder default applies.
    scale_args = () if scale == "none" else (scale,)
    mod = get_relax_stacked_attention_module(
        qkv, b, s, n, h, h_v, "split", bias, *scale_args, single_shape=single_shape
    )

    inputs = (qkv,) if bias is None else (qkv, bias)
    out = get_result_with_relax_cutlass_offload(mod, *inputs, num_final_bindings=3)
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_stacked_attention_strided_slice_offload(stacked_attention_size):
    """Attention over Q/K/V obtained with R.strided_slice from a stacked tensor."""
    b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size
    qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32")

    # Forward the scale only when one is set; otherwise the builder default applies.
    scale_args = () if scale == "none" else (scale,)
    mod = get_relax_stacked_attention_module(
        qkv, b, s, n, h, h_v, "strided_slice", bias, *scale_args, single_shape=single_shape
    )

    inputs = (qkv,) if bias is None else (qkv, bias)
    out = get_result_with_relax_cutlass_offload(mod, *inputs, num_final_bindings=3)
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
@pytest.fixture(
    params=[
        # B, S, N, H, bias_shape, scale
        (4, (16, 8), 32, (8, 16), "none", 0.5),
        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), 0.5),
        (4, (16, 8), "none", (8, 16), "none", 0.5),
        (4, (16, 8), "none", (8, 16), (4, 32, 16, 8), 0.5),
    ]
)
def attention_rewrite_size(request):
    """Sizes for the attention-rewrite tests.

    NOTE(review): ``N == "none"`` presumably selects 3-D (head-less) inputs; confirm
    against the test that consumes this fixture.
    """
    config = request.param
    return config
| |
| |
def get_relax_attention_rewrite_module(
    q_shape, k_shape, v_shape, out_shape, dtype, bias_shape=None, scale=None
):
    """Build the (original, expected) module pair for the attention-rewrite test.

    * ``original`` spells attention out as
      matmul -> multiply(scale) -> [add(bias)] -> softmax -> matmul.
      For each 4-D operand it swaps axes 1 and 2 and merges the first two axes,
      so every matmul is a 3-D batched matmul. The caller passes 4-D shapes as
      ``[b, s, n, h]``.
    * ``expected`` is the same computation written as a single
      ``R.nn.attention`` call on 4-D operands. 3-D inputs get a singleton
      axis 2.

    Parameters
    ----------
    q_shape, k_shape, v_shape, out_shape : list of int
        Shapes of q/k/v and of the result, each either 3-D or 4-D.
    dtype : str
        Dtype of the q/k/v/bias arguments.
    bias_shape : list of int, optional
        Shape of the bias argument, or ``None`` for no bias.
    scale : float, optional
        Scale applied to the q @ k^T scores.
        NOTE(review): the ``original`` path passes this straight to ``R.const``,
        so ``None`` (the default) is not supported there.

    Returns
    -------
    tuple of tvm.IRModule
        ``(original_mod, expected_mod)``, each with a single ``main`` function.
    """
    # NOTE(review): IRBuilder and relax_builder are also imported at module level.
    # ``tir as T`` here shadows the module-level ``tvm.script.tir`` alias inside
    # this function.
    from tvm.script.ir_builder import IRBuilder
    from tvm.script.ir_builder import relax as relax_builder
    from tvm.script.ir_builder import tir as T

    # ---- original: attention written out as batched matmuls + softmax ----
    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            q = R.arg("q", R.Tensor(q_shape, dtype))
            k = R.arg("k", R.Tensor(k_shape, dtype))
            v = R.arg("v", R.Tensor(v_shape, dtype))
            if bias_shape is not None:
                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
            with R.dataflow() as frame:
                # 4-D -> 3-D: swap axes 1 and 2, then merge the first two axes.
                if len(q_shape) == 4:
                    q = R.emit(R.permute_dims(q, axes=[0, 2, 1, 3]))
                    q = R.emit(R.reshape(q, [q_shape[0] * q_shape[2], q_shape[1], q_shape[3]]))

                if len(k_shape) == 4:
                    k = R.emit(R.permute_dims(k, axes=[0, 2, 1, 3]))
                    k = R.emit(R.reshape(k, [k_shape[0] * k_shape[2], k_shape[1], k_shape[3]]))

                if len(v_shape) == 4:
                    v = R.emit(R.permute_dims(v, axes=[0, 2, 1, 3]))
                    v = R.emit(R.reshape(v, [v_shape[0] * v_shape[2], v_shape[1], v_shape[3]]))

                # Transpose k's last two axes so q @ k^T is a plain batched matmul.
                k = R.emit(R.permute_dims(k, axes=[0, 2, 1]))
                qk = R.emit(R.matmul(q, k))
                # NOTE(review): the scale constant is always float32, independent of ``dtype``.
                qk_scaled = R.emit(R.multiply(qk, R.const(scale, "float32")))
                if bias_shape is not None:
                    # Merge the first two axes of a 4-D bias so it matches the 3-D scores.
                    if len(bias_shape) == 4:
                        bias = R.emit(
                            R.reshape(bias, [bias_shape[0] * bias_shape[1], *bias_shape[2:]])
                        )
                    qk_added = R.emit(R.add(qk_scaled, bias))
                    softmax = R.emit(R.nn.softmax(qk_added, axis=-1))
                else:
                    softmax = R.emit(R.nn.softmax(qk_scaled, axis=-1))
                out = R.emit(R.matmul(softmax, v))

                # Undo the folding: split axis 0 back out, then swap axes 1 and 2 again.
                if len(out_shape) == 4:
                    out = R.emit(
                        R.reshape(
                            out,
                            [out_shape[0], out_shape[2], out_shape[1], out_shape[3]],
                        )
                    )
                    out = R.emit(R.permute_dims(out, axes=[0, 2, 1, 3]))
                R.output(out)

            R.func_ret_value(frame.output_vars[0])

    original_func = builder.get()

    # ``R.nn.attention`` receives the scale as a float32 immediate.
    if scale is not None:
        scale = T.FloatImm("float32", scale)

    # ---- expected: the same computation as a single R.nn.attention call ----
    with IRBuilder() as builder:
        with relax_builder.function():
            R.func_name("main")
            q = R.arg("q", R.Tensor(q_shape, dtype))
            k = R.arg("k", R.Tensor(k_shape, dtype))
            v = R.arg("v", R.Tensor(v_shape, dtype))
            if bias_shape is not None:
                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
            with R.dataflow() as frame:
                # Make 3-D inputs 4-D by inserting a singleton axis 2.
                if len(q_shape) == 3:
                    q = R.emit(R.reshape(q, [q_shape[0], q_shape[1], 1, q_shape[2]]))

                if len(k_shape) == 3:
                    k = R.emit(R.reshape(k, [k_shape[0], k_shape[1], 1, k_shape[2]]))

                if len(v_shape) == 3:
                    v = R.emit(R.reshape(v, [v_shape[0], v_shape[1], 1, v_shape[2]]))

                if bias_shape is not None:
                    if len(bias_shape) == 4:
                        # 4-D -> 3-D -> 4-D round trip.
                        # NOTE(review): presumably this mirrors the bias reshape that remains
                        # outside the fused function after rewriting ``original``; confirm.
                        bias = R.emit(
                            R.reshape(
                                bias,
                                [
                                    bias_shape[0] * bias_shape[1],
                                    bias_shape[2],
                                    bias_shape[3],
                                ],
                            )
                        )
                        bias = R.emit(
                            R.reshape(
                                bias,
                                [
                                    bias_shape[0],
                                    bias_shape[1],
                                    bias_shape[2],
                                    bias_shape[3],
                                ],
                            )
                        )
                    elif len(bias_shape) == 3:
                        # Insert a singleton axis 1 to make a 3-D bias 4-D.
                        bias = R.emit(
                            R.reshape(bias, [bias_shape[0], 1, bias_shape[1], bias_shape[2]])
                        )
                else:
                    bias = None
                out = R.emit(R.nn.attention(q, k, v, bias, scale))

                # Drop the singleton axis again when a 3-D result is expected.
                if len(out_shape) == 3:
                    out = R.emit(R.reshape(out, [out_shape[0], out_shape[1], out_shape[2]]))
                R.output(out)

            R.func_ret_value(frame.output_vars[0])

    expected_func = builder.get()
    return tvm.IRModule({"main": original_func}), tvm.IRModule({"main": expected_func})
| |
| |
def get_numpy_attention_input(q_shape, k_shape, v_shape, bias_shape, dtype):
    """Draw standard-normal q, k, v (and optionally bias) arrays.

    Parameters
    ----------
    q_shape, k_shape, v_shape : sequence of int
        Shapes of the query, key and value arrays.
    bias_shape : sequence of int, ``"none"`` or ``None``
        Shape of the bias array. Both ``"none"`` and ``None`` mean "no bias".
    dtype : str
        NumPy dtype all arrays are cast to.

    Returns
    -------
    tuple
        ``(q, k, v, bias)``, where ``bias`` is ``None`` when no bias is requested.
    """
    q = np.random.randn(*q_shape).astype(dtype)
    k = np.random.randn(*k_shape).astype(dtype)
    v = np.random.randn(*v_shape).astype(dtype)
    if bias_shape is None or bias_shape == "none":
        bias = None
    else:
        bias = np.random.randn(*bias_shape).astype(dtype)
    return q, k, v, bias
| |
| |
def test_attention_rewrite_offload(attention_rewrite_size):
    """Check that ``partition_for_cutlass`` rewrites the spelled-out attention graph
    into the same module as an explicit ``R.nn.attention``, and that both agree
    numerically on CUDA."""
    b, (s, s_kv), n, (h, h_v), bias_shape, scale = attention_rewrite_size
    has_heads = n != "none"
    q_shape = [b, s, n, h] if has_heads else [b, s, h]
    k_shape = [b, s_kv, n, h] if has_heads else [b, s_kv, h]
    v_shape = [b, s_kv, n, h_v] if has_heads else [b, s_kv, h_v]
    out_shape = [b, s, n, h_v] if has_heads else [b, s, h_v]
    # The fixture only decides whether a bias is used; the concrete bias shape is
    # derived from the q/k/v shapes. A "none" entry must stay bias-free here.
    # Previously it was overwritten unconditionally, so the no-bias
    # parametrizations never actually ran without a bias.
    if bias_shape == "none":
        bias_shape = None
    else:
        bias_shape = [b, n, s, s_kv] if has_heads else [b, s, s_kv]
    q, k, v, bias = get_numpy_attention_input(
        q_shape, k_shape, v_shape, "none" if bias_shape is None else bias_shape, "float32"
    )
    original_mod, expected_mod = get_relax_attention_rewrite_module(
        q_shape, k_shape, v_shape, out_shape, "float32", bias_shape, scale
    )
    original_mod = partition_for_cutlass(original_mod, True)
    expected_mod = partition_for_cutlass(expected_mod, True)
    tvm.ir.assert_structural_equal(original_mod, expected_mod, True)

    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
    original_mod = codegen_pass(original_mod)
    expected_mod = codegen_pass(expected_mod)
    if bias is None:
        original_out = build_and_run(original_mod, [q, k, v], "cuda")
        expected_out = build_and_run(expected_mod, [q, k, v], "cuda")
        tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5)
    else:
        original_out = build_and_run(original_mod, [q, k, v, bias], "cuda", legalize=False)
        expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda", legalize=False)
        tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5)
| |
| |
def test_conv2d_residual_broadcast():
    """Offload conv2d + bias + relu + residual add, where the residual has batch
    size 1 against a batch-2 conv output, and compare with an LLVM build."""
    data_shape = (2, 64, 64, 8)
    weight_shape = (8, 3, 3, 8)
    dtype = "float16"
    out_channels = weight_shape[0]

    def get_mod(residual_batch):
        with IRBuilder() as builder:
            with relax_builder.function():
                R.func_name("main")
                data = R.arg("data", R.Tensor(data_shape, dtype))
                weight = R.arg("weight", R.Tensor(weight_shape, dtype))
                bias = R.arg("bias", R.Tensor((1, 1, out_channels), dtype))
                residual = R.arg("residual", R.Tensor((residual_batch, 1, 1, out_channels), dtype))

                with R.dataflow() as frame:
                    conv = R.emit(
                        R.nn.conv2d(
                            data,
                            weight,
                            out_dtype=dtype,
                            padding=(1, 1),
                            data_layout="NHWC",
                            kernel_layout="OHWI",
                        )
                    )
                    biased = R.emit(conv + bias)
                    activated = R.emit(R.nn.relu(biased))
                    summed = R.emit(R.add(activated, residual))
                    R.output(summed)

                R.func_ret_value(frame.output_vars[0])

        return tvm.IRModule({"main": builder.get()})

    low, high = -1, 1
    residual_batch = 1
    mod = get_mod(residual_batch)

    def draw(shape):
        return np.random.randint(low, high, size=shape).astype(dtype)

    # Draw order (data, weight, bias, residual) matches the argument order.
    args = [
        draw(data_shape),
        draw(weight_shape),
        draw((1, 1, out_channels)),
        draw((residual_batch, 1, 1, out_channels)),
    ]
    out = get_result_with_relax_cutlass_offload(mod, *args)
    ref = build_and_run(mod, args, "llvm")
    tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
| |
| |
@pytest.mark.parametrize(
    "data_shape, dtype, axes",
    [
        ((2, 128, 64), "float16", [-1]),
        ((128, 30), "float32", [-1]),
        ((2, 128, 64), "float32", [1]),
        ((2, 128, 64), "float32", [1, 2]),
    ],
)
def test_layer_norm(data_shape, dtype, axes):
    """Partitioning must leave layer_norm untouched unless it normalizes over the
    last axis alone. In that case, the offloaded result is compared with LLVM."""

    def get_mod(shape, elem_dtype, norm_axes):
        param_shape = [shape[axis] for axis in norm_axes]
        with IRBuilder() as builder:
            with relax_builder.function():
                R.func_name("main")
                inp = R.arg("input", R.Tensor(shape, elem_dtype))
                gamma = R.arg("gamma", R.Tensor(param_shape, elem_dtype))
                beta = R.arg("beta", R.Tensor(param_shape, elem_dtype))

                with R.dataflow() as frame:
                    normed = R.emit(R.nn.layer_norm(inp, gamma, beta, norm_axes))
                    R.output(normed)

                R.func_ret_value(frame.output_vars[0])

        return tvm.IRModule({"main": builder.get()})

    orig_mod = get_mod(data_shape, dtype, axes)
    mod = partition_for_cutlass(orig_mod)

    last_axis_only = len(axes) == 1 and axes[0] in (-1, len(data_shape) - 1)
    if not last_axis_only:
        tvm.ir.assert_structural_equal(mod, orig_mod)
        return

    mod = relax.transform.RunCodegen()(mod)

    hidden = data_shape[-1]
    inp = np.random.randn(*data_shape).astype(dtype)
    gamma = np.random.randn(hidden).astype(dtype)
    beta = np.random.randn(hidden).astype(dtype)
    run_args = [inp, gamma, beta]
    out = build_and_run(mod, run_args, "cuda")
    ref = build_and_run(orig_mod, run_args, "llvm")

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_attention_rewrite_fp16():
    """Check that an fp16 attention graph with a bias, including redundant fp16 casts,
    is partitioned into a single ``cutlass.attention_bias`` composite function.

    ``Module.main`` ends with an ``astype`` to float32, so its return annotation is
    float32. This matches the value it returns and ``Expected.main``.
    """

    @I.ir_module
    class Module:
        @R.function
        def main(
            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
            k: R.Tensor((4, 8, 32, 8), dtype="float16"),
            v: R.Tensor((4, 8, 32, 16), dtype="float16"),
            bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
        ) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
            R.func_attr({"num_input": 4})
            with R.dataflow():
                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
                lv1 = R.reshape(lv, R.shape([128, 16, 8]))
                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
                lv3 = R.reshape(lv2, R.shape([128, 8, 8]))
                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
                lv5 = R.reshape(lv4, R.shape([128, 8, 16]))
                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
                lv3_1 = R.astype(R.const(0.5, "float32"), dtype="float16")
                lv8 = R.multiply(lv7, lv3_1)
                lv9 = R.reshape(bias, R.shape([128, 16, 8]))
                lv10 = R.add(lv8, lv9)
                lv10_fp16 = R.astype(lv10, dtype="float16")
                lv11 = R.nn.softmax(lv10_fp16, axis=2)
                lv5_1 = R.astype(lv11, dtype="float16")
                lv12 = R.matmul(lv5_1, lv5, out_dtype="float16")
                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
                lv14 = R.astype(lv6_1, dtype="float32")
                R.output(lv14)
            return lv14

    @I.ir_module
    class Expected:
        @R.function
        def fused_relax_nn_attention_bias_cutlass1(
            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
            k: R.Tensor((4, 8, 32, 8), dtype="float16"),
            v: R.Tensor((4, 8, 32, 16), dtype="float16"),
            lv1: R.Tensor((4, 32, 16, 8), dtype="float16"),
            workspace: R.Tensor((65536,), dtype="uint8"),
        ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
            R.func_attr(
                {
                    "Codegen": "cutlass",
                    "WorkspaceSize": T.int64(65536),
                    "global_symbol": "fused_relax_nn_attention_bias_cutlass1",
                }
            )

            @R.function
            def gv_1(
                q_1: R.Tensor((4, 16, 32, 8), dtype="float16"),
                k_1: R.Tensor((4, 8, 32, 8), dtype="float16"),
                v_1: R.Tensor((4, 8, 32, 16), dtype="float16"),
                lv1_1: R.Tensor((4, 32, 16, 8), dtype="float16"),
                workspace_1: R.Tensor((65536,), dtype="uint8"),
            ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
                R.func_attr(
                    {
                        "Composite": "cutlass.attention_bias",
                        "WorkspaceSize": T.int64(65536),
                    }
                )
                with R.dataflow():
                    gv_2 = R.nn.attention(
                        q_1, k_1, v_1, lv1_1, scale=T.float32(0.5), causal_mask=None
                    )
                    R.output(gv_2)
                return gv_2

            gv1: R.Tensor((4, 16, 32, 16), dtype="float16") = gv_1(q, k, v, lv1, workspace)
            return gv1

        @R.function
        def main(
            q: R.Tensor((4, 16, 32, 8), dtype="float16"),
            k: R.Tensor((4, 8, 32, 8), dtype="float16"),
            v: R.Tensor((4, 8, 32, 16), dtype="float16"),
            bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
        ) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
            R.func_attr({"num_input": 4})
            cls = Expected
            with R.dataflow():
                lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
                workspace_main = R.vm.alloc_tensor(
                    lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
                )
                lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
                lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
                lv_2 = cls.fused_relax_nn_attention_bias_cutlass1(q, k, v, lv1, workspace_main)
                lv14 = R.astype(lv_2, dtype="float32")
                R.output(lv14)
            return lv14

    mod = partition_for_cutlass(Module)
    tvm.ir.assert_structural_equal(mod, Expected)
| |
| |
def split_transform_deploy_mod(mod):
    """Split a module produced by ``LiftTransformParams`` into two modules.

    Functions whose name contains ``"transform_params"`` go to the transform
    module, together with every TIR ``PrimFunc``. All other functions go to the
    deploy module, which also inherits ``mod``'s attributes.

    Returns
    -------
    tuple
        ``(mod_transform, mod_deploy, transform_func_name)``. If several names
        match, the last one seen is returned.
    """
    transform_part = tvm.IRModule()
    deploy_part = tvm.IRModule().with_attrs(mod.attrs)
    lifted_name = None

    for gvar, func in mod.functions.items():
        is_lifted = "transform_params" in gvar.name_hint
        if is_lifted:
            lifted_name = gvar.name_hint
        if is_lifted or isinstance(func, tvm.tir.PrimFunc):
            transform_part[gvar] = func
        else:
            deploy_part[gvar] = func

    assert lifted_name is not None
    return transform_part, deploy_part, lifted_name
| |
| |
def test_fp16A_int4B_gemm():
    """Offload fp16 x int4-weight GEMM variants to CUTLASS and check them numerically.

    The variants are matmul+bias, matmul+cast+bias, and matmul+bias+residual.

    The weight ``y`` is encoded by the ``encode`` PrimFunc and passed through
    ``cutlass.ft_preprocess_weight``. That parameter-only work is lifted into a
    ``transform_params`` function, which runs on CPU. Each deploy entry point then
    runs on CUDA and is compared with ``x @ y.T + bias (+ residual)`` in NumPy.
    """

    @I.ir_module
    class Module:
        @T.prim_func
        def decode(
            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            B: T.Buffer((T.int64(128),), "float16"),
            decode_1: T.Buffer((T.int64(64), T.int64(128)), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j in T.grid(T.int64(64), T.int64(128)):
                with T.block("decode"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j // T.int64(2)], B[v_j])
                    T.writes(decode_1[v_i, v_j])
                    decode_1[v_i, v_j] = (
                        T.Cast(
                            "float16",
                            T.shift_right(
                                T.shift_left(
                                    T.bitwise_and(
                                        T.shift_right(
                                            T.Cast("int32", A[v_i, v_j // T.int64(2)]),
                                            T.Cast("int32", v_j % T.int64(2)) * 4,
                                        ),
                                        15,
                                    ),
                                    28,
                                ),
                                28,
                            ),
                        )
                        * B[v_j]
                    )

        @T.prim_func
        def encode(
            A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            compute: T.Buffer((T.int64(128),), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            max_abs_value = T.alloc_buffer((T.int64(128),), "float16")
            scale = T.alloc_buffer((T.int64(128),))
            for i, k in T.grid(T.int64(128), T.int64(64)):
                with T.block("max_abs_value"):
                    v_i, v_k = T.axis.remap("SR", [i, k])
                    T.reads(A[v_i, v_k])
                    T.writes(max_abs_value[v_i])
                    with T.init():
                        max_abs_value[v_i] = T.float16(-65504)
                    max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
            for i in range(T.int64(128)):
                with T.block("scale"):
                    v_i = T.axis.spatial(T.int64(128), i)
                    T.reads(max_abs_value[v_i])
                    T.writes(scale[v_i])
                    scale[v_i] = T.max(
                        T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001)
                    ) * T.float32(0.125)
            for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)):
                with T.block("w_gathered"):
                    v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k])
                    T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * T.int64(2) + v_k])
                    T.writes(w_gathered[v_j, v_i])
                    with T.init():
                        w_gathered[v_j, v_i] = T.int8(0)
                    w_gathered[v_j, v_i] = T.bitwise_or(
                        w_gathered[v_j, v_i],
                        T.if_then_else(
                            v_i * T.int64(2) + v_k < T.int64(128),
                            T.shift_left(
                                T.bitwise_and(
                                    T.Cast(
                                        "int8",
                                        T.min(
                                            T.max(
                                                T.round(
                                                    T.Cast(
                                                        "float32", A[v_i * T.int64(2) + v_k, v_j]
                                                    )
                                                    / scale[v_i * T.int64(2) + v_k]
                                                ),
                                                T.float32(-8),
                                            ),
                                            T.float32(7),
                                        ),
                                    ),
                                    T.int8(15),
                                ),
                                T.Cast("int8", v_k) * T.int8(4),
                            ),
                            T.int8(0),
                        ),
                    )
            for i0 in range(T.int64(128)):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(T.int64(128), i0)
                    T.reads(scale[v_i0])
                    T.writes(compute[v_i0])
                    compute[v_i0] = T.Cast("float16", scale[v_i0])

        @R.function
        def main_bias(
            x: R.Tensor((64, 64), dtype="float16"),
            y: R.Tensor((128, 64), dtype="float16"),
            bias: R.Tensor((1, 128), dtype="float16"),
        ) -> R.Tensor((64, 128), dtype="float16"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(
                    cls.encode,
                    (y,),
                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")],
                )
                lv1 = lv[0]
                lv2 = R.call_pure_packed(
                    "cutlass.ft_preprocess_weight",
                    lv1,
                    80,
                    True,
                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                )
                lv3: R.Tensor((128,), dtype="float16") = lv[1]
                lv6 = R.call_tir(
                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16")
                )
                lv1_1: R.Tensor((64, 128), dtype="float16") = R.matmul(x, lv6, out_dtype="float16")
                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, bias)
                R.output(lv2_1)
            return lv2_1

        @R.function
        def main_cast_bias(
            x: R.Tensor((64, 64), dtype="float16"),
            y: R.Tensor((128, 64), dtype="float16"),
            bias: R.Tensor((1, 128), dtype="float16"),
        ) -> R.Tensor((64, 128), dtype="float16"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(
                    cls.encode,
                    (y,),
                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")],
                )
                lv1 = lv[0]
                lv2 = R.call_pure_packed(
                    "cutlass.ft_preprocess_weight",
                    lv1,
                    80,
                    True,
                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                )
                lv3: R.Tensor((128,), dtype="float16") = lv[1]
                lv6 = R.call_tir(
                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16")
                )
                lv1_1: R.Tensor((64, 128), dtype="float32") = R.matmul(x, lv6, out_dtype="float32")
                cast: R.Tensor((64, 128), dtype="float16") = R.astype(lv1_1, dtype="float16")
                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(cast, bias)
                R.output(lv2_1)
            return lv2_1

        @R.function
        def main_residual(
            x: R.Tensor((64, 64), dtype="float16"),
            residual: R.Tensor((64, 128), dtype="float16"),
            y: R.Tensor((128, 64), dtype="float16"),
            bias: R.Tensor((1, 128), dtype="float16"),
        ) -> R.Tensor((64, 128), dtype="float16"):
            R.func_attr({"num_input": 2})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(
                    cls.encode,
                    (y,),
                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((128,), dtype="float16")],
                )
                lv1 = lv[0]
                lv2 = R.call_pure_packed(
                    "cutlass.ft_preprocess_weight",
                    lv1,
                    80,
                    True,
                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                )
                lv3: R.Tensor((128,), dtype="float16") = lv[1]
                lv6 = R.call_tir(
                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), dtype="float16")
                )
                lv1_1: R.Tensor((64, 128), dtype="float16") = R.matmul(x, lv6, out_dtype="float16")
                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, bias)
                lv3_1: R.Tensor((64, 128), dtype="float16") = R.add(lv2_1, residual)
                R.output(lv3_1)
            return lv3_1

    x_shape = (64, 64)
    y_shape = (128, 64)

    # Each entry point should produce its own fused decode+matmul(+epilogue) function.
    mod = partition_for_cutlass(Module)
    func_names = [name.name_hint for (name, _) in mod.functions.items()]
    assert "fused_decode_relax_matmul_relax_add_cutlass" in func_names
    assert "fused_decode_relax_matmul_relax_add_relax_add_cutlass" in func_names
    assert "fused_decode_relax_matmul_relax_astype_relax_add_cutlass" in func_names

    mod = relax.transform.RunCodegen(
        {"cutlass": {"sm": 80, "find_first_valid": False}},
        entry_functions=["main_bias", "main_residual", "main_cast_bias"],
    )(mod)

    x = np.random.randn(*x_shape).astype("float16")
    # NOTE(review): the small std for y presumably keeps fp16 accumulation and
    # quantization error within the tolerance below; confirm.
    y = np.random.normal(0, 0.002, size=y_shape).astype("float16")
    bias = np.random.randn(1, y_shape[0]).astype("float16")
    residual = np.random.randn(x_shape[0], y_shape[0]).astype("float16")

    mod = relax.pipeline.get_pipeline()(mod)
    # Move parameter-only computation into a separate transform_params function.
    mod = relax.transform.LiftTransformParams()(mod)

    mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)

    # Run the lifted parameter transform on CPU to get the deploy-time parameters.
    ex = relax.build(mod_transform, target="llvm")
    vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))

    packed_weight, scales, bias_trans = vm[transform_func_name](
        (tvm.nd.array(y), tvm.nd.array(bias))
    )

    dev = tvm.device("cuda", 0)
    ex = relax.build(mod_deploy, target="cuda")
    vm = relax.vm.VirtualMachine(ex, dev)

    x_nd = tvm.nd.array(x, dev)
    residual_nd = tvm.nd.array(residual, dev)
    params = [packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev)]

    for f_name in ["main_bias", "main_cast_bias", "main_residual"]:
        # main_residual takes the residual as a second runtime input (num_input=2).
        with_residual = "residual" in f_name

        if with_residual:
            inp = [x_nd, residual_nd] + params
        else:
            inp = [x_nd] + params

        out = vm[f_name](*inp).numpy()

        ref = np.dot(x, y.transpose()) + bias

        if with_residual:
            ref += residual

        tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_fp16A_int8B_gemm():
    """Offload an fp16 x int8-weight GEMM with a bias + gelu epilogue to CUTLASS.

    The weight encode/preprocess is lifted into ``transform_params`` and run on CPU.
    The deploy module runs on CUDA and is compared with
    ``gelu(x @ y.T + bias)`` in NumPy/SciPy.

    The add/gelu bindings are annotated ``(64, 64)``: a ``(64, 64)`` matmul result
    plus a ``(64, 64)`` bias cannot be ``(64, 128)``, and the function itself returns
    ``(64, 64)``.
    """

    @I.ir_module
    class Module:
        @T.prim_func
        def decode(
            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            B: T.Buffer((T.int64(64),), "float16"),
            decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j in T.grid(T.int64(64), T.int64(64)):
                with T.block("decode"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j], B[v_j])
                    T.writes(decode_1[v_i, v_j])
                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_j]

        @T.prim_func
        def encode(
            A: T.Buffer((T.int64(64), T.int64(64)), "float16"),
            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            compute: T.Buffer((T.int64(64),), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            max_abs_value = T.alloc_buffer((T.int64(64),), "float16")
            scale = T.alloc_buffer((T.int64(64),))
            for i, k in T.grid(T.int64(64), T.int64(64)):
                with T.block("max_abs_value"):
                    v_i, v_k = T.axis.remap("SR", [i, k])
                    T.reads(A[v_i, v_k])
                    T.writes(max_abs_value[v_i])
                    with T.init():
                        max_abs_value[v_i] = T.float16(-65504)
                    max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
            for i in range(T.int64(64)):
                with T.block("scale"):
                    v_i = T.axis.spatial(T.int64(64), i)
                    T.reads(max_abs_value[v_i])
                    T.writes(scale[v_i])
                    scale[v_i] = T.max(
                        T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001)
                    ) * T.float32(0.0078125)
            for j, i in T.grid(T.int64(64), T.int64(64)):
                with T.block("w_gathered"):
                    v_j, v_i = T.axis.remap("SS", [j, i])
                    T.reads(A[v_i, v_j], scale[v_i])
                    T.writes(w_gathered[v_j, v_i])
                    w_gathered[v_j, v_i] = T.Cast(
                        "int8",
                        T.min(
                            T.max(
                                T.round(T.Cast("float32", A[v_i, v_j]) / scale[v_i]),
                                T.float32(-128),
                            ),
                            T.float32(127),
                        ),
                    )
            for i0 in range(T.int64(64)):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(T.int64(64), i0)
                    T.reads(scale[v_i0])
                    T.writes(compute[v_i0])
                    compute[v_i0] = T.Cast("float16", scale[v_i0])

        @R.function
        def main(
            x: R.Tensor((64, 64), dtype="float16"),
            y: R.Tensor((64, 64), dtype="float16"),
            bias: R.Tensor((64, 64), dtype="float16"),
        ) -> R.Tensor((64, 64), dtype="float16"):
            R.func_attr({"num_input": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(
                    cls.encode,
                    (y,),
                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")],
                )
                lv1: R.Tensor((64, 64), dtype="int8") = lv[0]
                lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed(
                    "cutlass.ft_preprocess_weight",
                    lv1,
                    R.prim_value(80),
                    R.prim_value(0),
                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                )
                lv3: R.Tensor((64,), dtype="float16") = lv[1]
                lv4: R.Tensor((64, 64), dtype="int8") = R.builtin.stop_lift_params(lv2)
                lv5: R.Tensor((64,), dtype="float16") = R.builtin.stop_lift_params(lv3)
                lv6 = R.call_tir(
                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16")
                )
                lv1_1: R.Tensor((64, 64), dtype="float16") = R.matmul(x, lv6, out_dtype="float16")
                lv2_1: R.Tensor((64, 64), dtype="float16") = R.add(lv1_1, bias)
                lv2_2: R.Tensor((64, 64), dtype="float16") = R.nn.gelu(lv2_1)
                R.output(lv2_2)
            return lv2_2

    x_shape = (64, 64)
    y_shape = (64, 64)

    mod = partition_for_cutlass(Module)
    func_names = [name.name_hint for (name, _) in mod.functions.items()]
    assert "fused_decode_relax_matmul_relax_add_relax_nn_gelu_cutlass" in func_names

    mod = relax.transform.RunCodegen(
        {"cutlass": {"sm": 80, "find_first_valid": False}},
    )(mod)

    x = np.random.randn(*x_shape).astype("float16")
    y = np.random.normal(0, 0.002, size=y_shape).astype("float16")
    bias = np.random.randn(x_shape[0], y_shape[0]).astype("float16")

    mod = relax.pipeline.get_pipeline()(mod)
    # Move parameter-only computation into a separate transform_params function.
    mod = relax.transform.LiftTransformParams()(mod)

    mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)

    # Run the lifted parameter transform on CPU to get the deploy-time parameters.
    ex = relax.build(mod_transform, target="llvm")
    vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))

    packed_weight, scales, bias_trans = vm[transform_func_name](
        (tvm.nd.array(y), tvm.nd.array(bias))
    )

    dev = tvm.device("cuda", 0)
    ex = relax.build(mod_deploy, target="cuda")
    vm = relax.vm.VirtualMachine(ex, dev)

    x_nd = tvm.nd.array(x, dev)
    inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev)]
    out = vm["main"](*inp).numpy()

    def gelu_fp16(x):
        # Erf-based gelu: 0.5 * x * (1 + erf(x / sqrt(2))), with erf evaluated in float32.
        erf_inp = x * (0.5**0.5)
        from scipy.special import erf

        erf_out = erf(erf_inp.astype("float32")).astype("float16")
        return x * 0.5 * (1.0 + erf_out)

    ref = gelu_fp16(np.dot(x, y.transpose()) + bias)
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_rms_norm():
    """Offload a TIR rms_norm PrimFunc to CUTLASS and compare it with an LLVM build.

    The PrimFunc computes ``B * A / sqrt(mean(A**2) + 1e-6)`` over the last axis,
    accumulating in float32. ``rms_eps`` is passed to the CUTLASS codegen with the
    same epsilon.
    """

    @I.ir_module
    class Module:
        @T.prim_func
        def rms_norm(
            A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
            B: T.Buffer((T.int64(4096),), "float16"),
            rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
            for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
                with T.block("Ared_temp"):
                    v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
                    T.reads(A[v_bsz, v_i, v_k])
                    T.writes(Ared_temp[v_bsz, v_i])
                    with T.init():
                        Ared_temp[v_bsz, v_i] = T.float32(0)
                    Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
                        "float32", A[v_bsz, v_i, v_k]
                    ) * T.Cast("float32", A[v_bsz, v_i, v_k])
            for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
                with T.block("rms_norm"):
                    v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
                    T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
                    T.writes(rms_norm[v_bsz, v_i, v_k])
                    # 0.000244140625 == 1 / 4096, turning the sum of squares into a mean.
                    rms_norm[v_bsz, v_i, v_k] = T.Cast(
                        "float16",
                        T.Cast("float32", B[v_k])
                        * (
                            T.Cast("float32", A[v_bsz, v_i, v_k])
                            / T.sqrt(
                                Ared_temp[v_bsz, v_i] * T.float32(0.000244140625)
                                + T.float32(9.9999999999999995e-07)
                            )
                        ),
                    )

        @R.function
        def main(
            input: R.Tensor((1, 1, 4096), dtype="float16"),
            weight: R.Tensor((4096,), dtype="float16"),
        ) -> R.Tensor((1, 1, 4096), dtype="float16"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(
                    cls.rms_norm, (input, weight), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")
                )
                R.output(lv)
            return lv

    data_shape = (1, 1, 4096)
    dtype = "float16"
    mod = partition_for_cutlass(Module)

    # TODO(@tvm-team): This is a temporary patch. Currently, the remaining PrimFunc
    # triggers an error because it is not scheduled.
    # This is because RunCodegen does not support PrimFunc well yet:
    # it does not remove the global symbol of a PrimFunc that is no longer used,
    # so the subsequent DCE cannot remove that PrimFunc. Revisit when resolved.
    with tvm.target.Target("cuda"):
        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

    mod = relax.transform.RunCodegen(
        {"cutlass": {"rms_eps": 1e-6}},
    )(mod)

    inp = np.random.randn(*data_shape).astype(dtype)
    weight = np.random.randn(data_shape[-1]).astype(dtype)
    out = build_and_run(mod, [inp, weight], "cuda")
    ref = build_and_run(Module, [inp, weight], "llvm", legalize=True)

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_conv2d_cuda_graph():
    """Offload three conv2d+relu layers, with a layer_norm between the second and
    third, and build with CUDA graph enabled. Compare against an LLVM build of the
    unpartitioned module."""

    @tvm.script.ir_module
    class Conv2d:
        @R.function
        def main(
            data: R.Tensor((16, 32, 32, 16), "float16"),
            weight1: R.Tensor((16, 3, 3, 16), "float16"),
            weight2: R.Tensor((16, 3, 3, 16), "float16"),
            weight3: R.Tensor((16, 3, 3, 16), "float16"),
            gamma: R.Tensor((16,), "float16"),
            beta: R.Tensor((16,), "float16"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                conv1 = R.nn.relu(
                    R.nn.conv2d(
                        data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                conv2 = R.nn.relu(
                    R.nn.conv2d(
                        conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
                conv3 = R.nn.relu(
                    R.nn.conv2d(
                        ln, weight3, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
                    )
                )
                R.output(conv3)

            return conv3

    low, high = -1, 1
    data_shape = (16, 32, 32, 16)
    weight_shape = (16, 3, 3, 16)
    dtype = "float16"
    channel_shape = (weight_shape[0],)

    def draw(shape):
        return np.random.randint(low, high, size=shape).astype(dtype)

    # Draw order (data, weight1..3, gamma, beta) matches main's parameter order.
    data = draw(data_shape)
    weights = [draw(weight_shape) for _ in range(3)]
    gamma = draw(channel_shape)
    beta = draw(channel_shape)
    inputs = [data, *weights, gamma, beta]

    partitioned = partition_for_cutlass(Conv2d)
    codegen = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
    lowered = relax.pipeline.get_pipeline()(codegen(partitioned))  # pylint: disable=no-value-for-parameter

    with tvm.target.Target("cuda"):
        scheduled = tvm.tir.transform.DefaultGPUSchedule()(lowered)

    out = build_and_run(scheduled, inputs, "cuda", cuda_graph=True)
    ref = build_and_run(Conv2d, inputs, "llvm", legalize=True)
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_fp16A_int8B_gemm_batched():
    """Offload an fp16 x int8-weight matmul whose activation has a symbolic batch
    dimension ``b``, and compare it with ``x @ y.T`` in NumPy.

    The weight encode/preprocess is lifted into ``transform_params`` and run on CPU.
    The deploy module runs on CUDA with a concrete batch of 4.
    """

    @I.ir_module
    class Module:
        @T.prim_func
        def decode(
            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            B: T.Buffer((T.int64(64),), "float16"),
            decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j in T.grid(T.int64(64), T.int64(64)):
                with T.block("decode"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j], B[v_j])
                    T.writes(decode_1[v_i, v_j])
                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_j]

        @T.prim_func
        def encode(
            A: T.Buffer((T.int64(64), T.int64(64)), "float16"),
            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            compute: T.Buffer((T.int64(64),), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            max_abs_value = T.alloc_buffer((T.int64(64),), "float16")
            scale = T.alloc_buffer((T.int64(64),))
            for i, k in T.grid(T.int64(64), T.int64(64)):
                with T.block("max_abs_value"):
                    v_i, v_k = T.axis.remap("SR", [i, k])
                    T.reads(A[v_i, v_k])
                    T.writes(max_abs_value[v_i])
                    with T.init():
                        max_abs_value[v_i] = T.float16(-65504)
                    max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
            for i in range(T.int64(64)):
                with T.block("scale"):
                    v_i = T.axis.spatial(T.int64(64), i)
                    T.reads(max_abs_value[v_i])
                    T.writes(scale[v_i])
                    scale[v_i] = T.max(
                        T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001)
                    ) * T.float32(0.0078125)
            for j, i in T.grid(T.int64(64), T.int64(64)):
                with T.block("w_gathered"):
                    v_j, v_i = T.axis.remap("SS", [j, i])
                    T.reads(A[v_i, v_j], scale[v_i])
                    T.writes(w_gathered[v_j, v_i])
                    w_gathered[v_j, v_i] = T.Cast(
                        "int8",
                        T.min(
                            T.max(
                                T.round(T.Cast("float32", A[v_i, v_j]) / scale[v_i]),
                                T.float32(-128),
                            ),
                            T.float32(127),
                        ),
                    )
            for i0 in range(T.int64(64)):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(T.int64(64), i0)
                    T.reads(scale[v_i0])
                    T.writes(compute[v_i0])
                    compute[v_i0] = T.Cast("float16", scale[v_i0])

        @R.function
        def main(
            x: R.Tensor(("b", 64, 64), dtype="float16"),
            y: R.Tensor((64, 64), dtype="float16"),
        ) -> R.Tensor(("b", 64, 64), dtype="float16"):
            R.func_attr({"num_input": 1})
            cls = Module
            # Symbolic batch size shared by x and the matmul result.
            b = T.int64()
            with R.dataflow():
                lv = R.call_tir(
                    cls.encode,
                    (y,),
                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")],
                )
                lv1: R.Tensor((64, 64), dtype="int8") = lv[0]
                lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed(
                    "cutlass.ft_preprocess_weight",
                    lv1,
                    R.prim_value(80),
                    R.prim_value(0),
                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                )
                lv3: R.Tensor((64,), dtype="float16") = lv[1]
                lv4: R.Tensor((64, 64), dtype="int8") = R.builtin.stop_lift_params(lv2)
                lv5: R.Tensor((64,), dtype="float16") = R.builtin.stop_lift_params(lv3)
                lv6 = R.call_tir(
                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16")
                )
                lv1_1: R.Tensor((b, 64, 64), dtype="float16") = R.matmul(
                    x, lv6, out_dtype="float16"
                )
                R.output(lv1_1)
            return lv1_1

    x_shape = (4, 64, 64)
    y_shape = (64, 64)

    mod = partition_for_cutlass(Module)

    mod = relax.transform.RunCodegen(
        {"cutlass": {"sm": 80, "find_first_valid": False}},
    )(mod)

    x = np.random.randn(*x_shape).astype("float16")
    y = np.random.normal(0, 0.002, size=y_shape).astype("float16")

    mod = relax.pipeline.get_pipeline()(mod)
    # Move parameter-only computation into a separate transform_params function.
    mod = relax.transform.LiftTransformParams()(mod)

    mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)

    # Run the lifted parameter transform on CPU. It yields the packed weight and scales.
    ex = relax.build(mod_transform, target="llvm")
    vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))

    (packed_weight, scales,) = vm[
        transform_func_name
    ]((tvm.nd.array(y),))

    dev = tvm.device("cuda", 0)
    ex = relax.build(mod_deploy, target="cuda")
    vm = relax.vm.VirtualMachine(ex, dev)

    x_nd = tvm.nd.array(x, dev)
    inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)]
    out = vm["main"](*inp).numpy()
    ref = np.dot(x, y.transpose())
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_fp16A_int8B_gemm_batched_finegrained():
    """Batched fp16 x int8 GEMM with group-wise weight scales, offloaded to CUTLASS.

    ``encode`` quantizes the 128x128 fp16 weight ``y`` to int8. It uses one scale
    per (64-column group, row of ``y``), which gives a (2, 128) scale tensor.
    ``decode`` dequantizes the weight again before ``R.matmul``. The module is
    partitioned for CUTLASS, code-generated for sm80, and split into a
    weight-transform module and a deploy module. The transform runs on the CPU
    and the deploy module runs on CUDA. The GPU result is compared against
    ``x @ y.T`` computed in NumPy.
    """
    @I.ir_module
    class Module:
        @T.prim_func
        def decode(
            A: T.Buffer((T.int64(128), T.int64(128)), "int8"),
            B: T.Buffer((T.int64(2), T.int64(128)), "float16"),
            decode_1: T.Buffer((T.int64(128), T.int64(128)), "float16"),
        ):
            # decode_1[i, j] = float16(A[i, j]) * B[i // 64, j]: each group of 64
            # rows of the int8 weight shares one row of scales.
            T.func_attr({"tir.noalias": T.bool(True)})
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block("decode"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i, v_j], B[v_i // T.int64(64), v_j])
                    T.writes(decode_1[v_i, v_j])
                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_i // T.int64(64), v_j]

        @T.prim_func
        def encode(
            A: T.Buffer((T.int64(128), T.int64(128)), "float16"),
            w_gathered: T.Buffer((T.int64(128), T.int64(128)), "int8"),
            compute: T.Buffer(
                (
                    T.int64(2),
                    T.int64(128),
                ),
                "float16",
            ),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            max_abs_value = T.alloc_buffer(
                (
                    T.int64(2),
                    T.int64(128),
                ),
                "float16",
            )
            scale = T.alloc_buffer(
                (
                    T.int64(2),
                    T.int64(128),
                )
            )
            # max_abs_value[g, r] = max over k in [0, 64) of |A[r, g * 64 + k]|,
            # i.e. the largest magnitude in the g-th 64-column group of row r.
            for i, j, k in T.grid(T.int64(2), T.int64(128), T.int64(64)):
                with T.block("max_abs_value"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[v_j, v_i * T.int64(64) + v_k])
                    T.writes(max_abs_value[v_i, v_j])
                    with T.init():
                        # -65504 is the lowest finite float16 value.
                        max_abs_value[v_i, v_j] = T.float16(-65504)
                    max_abs_value[v_i, v_j] = T.max(
                        max_abs_value[v_i, v_j], T.fabs(A[v_j, v_i * T.int64(64) + v_k])
                    )
            # scale = max(max_abs, 1e-4) * (1 / 128), computed in float32. The 1e-4
            # floor keeps the division in "w_gathered" finite for all-zero groups.
            for i, j in T.grid(T.int64(2), T.int64(128)):
                with T.block("scale"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(max_abs_value[v_i, v_j])
                    T.writes(scale[v_i, v_j])
                    scale[v_i, v_j] = T.max(
                        T.Cast("float32", max_abs_value[v_i, v_j]), T.float32(0.0001)
                    ) * T.float32(0.0078125)
            # Quantize into the transposed layout:
            # w_gathered[j, i] = int8(clamp(round(A[i, j] / scale[j // 64, i]), -128, 127)).
            for j, i in T.grid(T.int64(128), T.int64(128)):
                with T.block("w_gathered"):
                    v_j, v_i = T.axis.remap("SS", [j, i])
                    T.reads(A[v_i, v_j], scale[v_j // T.int64(64), v_i])
                    T.writes(w_gathered[v_j, v_i])
                    w_gathered[v_j, v_i] = T.Cast(
                        "int8",
                        T.min(
                            T.max(
                                T.round(
                                    T.Cast("float32", A[v_i, v_j]) / scale[v_j // T.int64(64), v_i]
                                ),
                                T.float32(-128),
                            ),
                            T.float32(127),
                        ),
                    )
            # Second output: the scales cast back to float16.
            for i0, i1 in T.grid(T.int64(2), T.int64(128)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(scale[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Cast("float16", scale[v_i0, v_i1])

        @R.function
        def main(
            x: R.Tensor(("b", 128, 128), dtype="float16"),
            y: R.Tensor((128, 128), dtype="float16"),
        ) -> R.Tensor(("b", 128, 128), dtype="float16"):
            # NOTE(review): "num_input": 1 presumably marks only `x` as a runtime
            # input, so that LiftTransformParams treats `y` as a parameter. Confirm.
            R.func_attr({"num_input": 1})
            cls = Module
            # Symbolic batch dimension, used in the matmul annotation below.
            b = T.int64()
            with R.dataflow():
                # Quantize y into (int8 transposed weight, float16 group scales).
                lv = R.call_tir(
                    cls.encode,
                    (y,),
                    out_sinfo=[
                        R.Tensor((128, 128), dtype="int8"),
                        R.Tensor((2, 128), dtype="float16"),
                    ],
                )
                lv1: R.Tensor((128, 128), dtype="int8") = lv[0]
                # NOTE(review): prim_value(80) presumably is the target SM version
                # (matching "sm": 80 below). The meaning of prim_value(0) is not
                # visible here. Confirm against the packed function.
                lv2: R.Tensor((128, 128), dtype="int8") = R.call_pure_packed(
                    "cutlass.ft_preprocess_weight",
                    lv1,
                    R.prim_value(80),
                    R.prim_value(0),
                    sinfo_args=(R.Tensor((128, 128), dtype="int8"),),
                )
                lv3: R.Tensor((2, 128), dtype="float16") = lv[1]
                # NOTE(review): stop_lift_params presumably marks the boundary of the
                # liftable weight transform, keeping decode/matmul in the deploy module.
                lv4: R.Tensor((128, 128), dtype="int8") = R.builtin.stop_lift_params(lv2)
                lv5: R.Tensor((2, 128), dtype="float16") = R.builtin.stop_lift_params(lv3)
                lv6 = R.call_tir(
                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((128, 128), dtype="float16")
                )
                lv1_1: R.Tensor((b, 128, 128), dtype="float16") = R.matmul(
                    x, lv6, out_dtype="float16"
                )
                R.output(lv1_1)
            return lv1_1

    x_shape = (4, 128, 128)
    y_shape = (128, 128)

    mod = partition_for_cutlass(Module)

    mod = relax.transform.RunCodegen(
        {"cutlass": {"sm": 80, "find_first_valid": False}},
    )(mod)

    x = np.random.randn(*x_shape).astype("float16")
    # NOTE(review): the small std (0.002) presumably keeps the fp16 GEMM output in
    # a range where rtol/atol of 1e-2 is a meaningful check.
    y = np.random.normal(0, 0.002, size=y_shape).astype("float16")

    mod = relax.pipeline.get_pipeline()(mod)
    mod = relax.transform.LiftTransformParams()(mod)

    # Separate the lifted parameter-transform function from the deploy module.
    mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)

    # Run the parameter transform on the CPU. It yields the packed int8 weight
    # and the float16 scales.
    ex = relax.build(mod_transform, target="llvm")
    vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))

    (packed_weight, scales,) = vm[
        transform_func_name
    ]((tvm.nd.array(y),))

    # Build and run the deploy module on the GPU with x plus the transformed params.
    dev = tvm.device("cuda", 0)
    ex = relax.build(mod_deploy, target="cuda")
    vm = relax.vm.VirtualMachine(ex, dev)

    x_nd = tvm.nd.array(x, dev)
    inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)]
    out = vm["main"](*inp).numpy()
    # encode stores the weight transposed, so decode(...) approximates y.T.
    ref = np.dot(x, y.transpose())
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_attention_rewrite_multi_query():
    """Offload multi-query attention through CUTLASS's flash MQA path.

    K and V each carry a single head. ``R.repeat`` broadcasts that head to the 32
    query heads, and the result feeds a hand-written attention subgraph:
    permute/reshape, Q @ K^T, scaling, fp32 softmax, then @ V.
    ``partition_for_cutlass`` runs with ``use_flash_mqa=True``. The CUTLASS result
    on CUDA is compared against the legalized LLVM build of the same module.
    """
    @I.ir_module
    class Module:
        @R.function
        def main(
            q: R.Tensor((4, 16, 32, 16), dtype="float16"),
            k_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
            v_single: R.Tensor((4, 16, 1, 16), dtype="float16"),
        ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
            # The output has the same shape as q, (4, 16, 32, 16). The final
            # reshape/permute undoes the head-major layout used below.
            with R.dataflow():
                # Broadcast the single K/V head across all 32 query heads.
                k = R.repeat(k_single, 32, axis=2)
                v = R.repeat(v_single, 32, axis=2)

                # (4, 16, 32, 16) -> (4, 32, 16, 16) -> (4 * 32, 16, 16)
                lv = R.permute_dims(q, axes=[0, 2, 1, 3])
                lv1 = R.reshape(lv, R.shape([128, 16, 16]))
                lv2 = R.permute_dims(k, axes=[0, 2, 1, 3])
                lv3 = R.reshape(lv2, R.shape([128, 16, 16]))
                lv4 = R.permute_dims(v, axes=[0, 2, 1, 3])
                lv5 = R.reshape(lv4, R.shape([128, 16, 16]))

                lv6 = R.permute_dims(lv3, axes=[0, 2, 1])
                lv7 = R.matmul(lv1, lv6, out_dtype="float16")
                # 0.25 == 1 / sqrt(16), where 16 is the head size.
                lv3_1 = R.astype(R.const(0.25, "float32"), "float16")
                lv8 = R.multiply(lv7, lv3_1)
                lv11 = R.astype(R.nn.softmax(R.astype(lv8, "float32"), axis=2), "float16")
                lv12 = R.matmul(lv11, lv5, out_dtype="float16")
                lv13 = R.reshape(lv12, R.shape([4, 32, 16, 16]))
                lv6_1 = R.permute_dims(lv13, axes=[0, 2, 1, 3])
                R.output(lv6_1)
            return lv6_1

    q_np = np.random.randn(4, 16, 32, 16).astype("float16")
    k_np = np.random.randn(4, 16, 1, 16).astype("float16")
    v_np = np.random.randn(4, 16, 1, 16).astype("float16")
    args = [q_np, k_np, v_np]
    # Reference: the unpartitioned module legalized and run on the CPU.
    ref = build_and_run(Module, args, "llvm", legalize=True)

    mod = partition_for_cutlass(Module, use_flash_mqa=True)
    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
    mod = codegen_pass(mod)

    out = build_and_run(mod, args, "cuda")

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def _test_batched_var_len_attention(
    mod, seq_lens, num_head, num_kv_head, head_size, window_size=None
):
    """Check a packed variable-length attention module against NumPy references.

    The helper builds one NumPy attention reference per sequence, using a
    bottom-right causal mask. Each sequence's queries, keys, values and reference
    are flattened to ``(tokens, heads * head_size)`` and stacked along the token
    axis; the stacked arrays are the packed layout passed to ``mod``'s ``main``.
    ``mod`` is then partitioned for CUTLASS, code-generated for sm80, legalized,
    GPU-scheduled and run on CUDA.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` takes ``(queries, keys, values, seq_lens)``.
    seq_lens : list[int]
        Length of each sequence in the batch.
    num_head : int
        Number of query heads.
    num_kv_head : int
        Number of key/value heads.
    head_size : int
        Per-head size. It is passed as both head-size arguments of
        ``get_numpy_attention_ref``.
    window_size : int, optional
        Forwarded to ``get_numpy_attention_ref`` only. Any window used by ``mod``
        must already be baked into the module and should agree with this value.
    """
    # The modules under test compute sequence offsets with the thrust scan
    # kernel. Report a skip, rather than a silent pass, when TVM lacks it.
    if not tvm.get_global_func("tvm.contrib.thrust.sum_scan", True):
        pytest.skip("tvm.contrib.thrust.sum_scan is not available")

    hidden_size = num_head * head_size
    kv_hidden_size = num_kv_head * head_size

    batched_queries = []
    batched_keys = []
    batched_values = []
    batched_refs = []

    for s in seq_lens:
        q, k, v, _, ref = get_numpy_attention_ref(
            1,
            s,
            s,
            num_head,
            head_size,
            head_size,
            "none",
            "none",
            "BottomRight",
            "float16",
            num_kv_head=num_kv_head,
            window_size=window_size,
        )
        # Flatten to (tokens, heads * head_size) so sequences can be packed.
        batched_queries.append(np.reshape(q, [-1, hidden_size]))
        batched_keys.append(np.reshape(k, [-1, kv_hidden_size]))
        batched_values.append(np.reshape(v, [-1, kv_hidden_size]))
        batched_refs.append(np.reshape(ref, [-1, hidden_size]))

    # Pack all sequences along the token axis.
    batched_queries = np.vstack(batched_queries)
    batched_keys = np.vstack(batched_keys)
    batched_values = np.vstack(batched_values)
    ref = np.vstack(batched_refs)

    mod = partition_for_cutlass(mod)
    codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
    mod = codegen_pass(mod)

    # Lower and schedule the remaining (non-CUTLASS) ops for the GPU.
    with tvm.target.Target("cuda"):
        mod = relax.transform.LegalizeOps()(mod)
        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

    out = build_and_run(
        mod,
        [
            batched_queries,
            batched_keys,
            batched_values,
            np.array(seq_lens, dtype="int32"),
        ],
        "cuda",
    )

    ############# xformer reference for verification #############

    # NOTE(review): xformers may require K/V heads to be expanded to num_head
    # for multi-query inputs. Confirm before using this with num_kv_head != num_head.

    # attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)

    # queries = torch.from_numpy(np.reshape(batched_queries, [1, -1, num_head, head_size])).to("cuda")
    # keys = torch.from_numpy(np.reshape(batched_keys, [1, -1, num_kv_head, head_size])).to("cuda")
    # values = torch.from_numpy(np.reshape(batched_values, [1, -1, num_kv_head, head_size])).to("cuda")

    # out = xops.memory_efficient_attention_forward(
    #    queries, keys, values,
    #    attn_bias=attn_bias,
    # ).cpu().numpy()[0]
    # out = np.reshape(out, [-1, hidden_size])

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_batched_var_len_attention():
    """Variable-length batched attention with equal query and KV head counts.

    Three sequences (lengths 5, 3 and 8) are packed along the token axis with
    128 heads of size 32 (128 * 32 == 4096) and checked by
    ``_test_batched_var_len_attention``.
    """
    @I.ir_module
    class Module:
        # A single llvm virtual device, referenced below as "llvm:0".
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                ]
            }
        )

        @R.function
        def main(
            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
            # NOTE(review): presumably marks all four parameters as runtime inputs.
            R.func_attr({"num_input": 4})
            cls = Module
            num_tokens = T.int64()
            num_seq = T.int64()

            with R.dataflow():
                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
                # https://github.com/apache/tvm/issues/15851
                cumsum = R.call_dps_packed(
                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
                )
                # Longest sequence, placed on the llvm vdevice declared above
                # (presumably because the kernel reads it on the host; confirm).
                max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
                # Prepend 0 to the scan. If sum_scan is inclusive, this gives each
                # sequence's start offset followed by the total token count.
                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
                # Unflatten to (1, num_tokens, 128 heads, head size 32).
                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
                attn_out = R.nn.attention_var_len(
                    q,
                    k,
                    v,
                    seqstart_q,
                    max_seqlen_q,
                    causal_mask="BottomRight",
                )
                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
                R.output(out)
            return out

    seq_lens = [5, 3, 8]
    num_head = 128
    head_size = 32

    # Same head count for queries and keys/values.
    _test_batched_var_len_attention(Module, seq_lens, num_head, num_head, head_size)
| |
| |
def test_batched_var_len_multi_query_attention():
    """Variable-length batched attention with fewer KV heads than query heads.

    Queries use 128 heads of size 32 (4096 features). Keys and values use 16
    heads of size 32 (512 features). Three sequences (lengths 5, 3 and 8) are
    checked by ``_test_batched_var_len_attention``.
    """
    @I.ir_module
    class Module:
        # A single llvm virtual device, referenced below as "llvm:0".
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                ]
            }
        )

        @R.function
        def main(
            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
            keys: R.Tensor(("num_tokens", 512), dtype="float16"),
            values: R.Tensor(("num_tokens", 512), dtype="float16"),
            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
            # NOTE(review): presumably marks all four parameters as runtime inputs.
            R.func_attr({"num_input": 4})
            cls = Module
            num_tokens = T.int64()
            num_seq = T.int64()

            with R.dataflow():
                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
                # https://github.com/apache/tvm/issues/15851
                cumsum = R.call_dps_packed(
                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
                )
                # Longest sequence, placed on the llvm vdevice declared above
                # (presumably because the kernel reads it on the host; confirm).
                max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
                # Prepend 0 to the scan. If sum_scan is inclusive, this gives each
                # sequence's start offset followed by the total token count.
                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
                # Queries: 128 heads x 32. Keys/values: 16 heads x 32.
                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
                k = R.reshape(keys, R.shape([1, num_tokens, 16, 32]))
                v = R.reshape(values, R.shape([1, num_tokens, 16, 32]))
                attn_out = R.nn.attention_var_len(
                    q,
                    k,
                    v,
                    seqstart_q,
                    max_seqlen_q,
                    causal_mask="BottomRight",
                )
                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
                R.output(out)
            return out

    seq_lens = [5, 3, 8]
    num_head = 128
    num_kv_head = 16
    head_size = 32

    _test_batched_var_len_attention(Module, seq_lens, num_head, num_kv_head, head_size)
| |
| |
def test_sliding_window():
    """Offload bottom-right causal attention with a sliding window and compare to NumPy.

    Q, K and V share one (batch, seq_len, num_heads, head_size) shape. The NumPy
    reference uses the same causal mask and window size as the Relax module.
    """
    batch, seq_len, num_heads, head_size = 1, 64, 16, 8
    q_shape = (batch, seq_len, num_heads, head_size)
    k_shape = q_shape
    v_shape = q_shape
    window_size = 8
    causal = "BottomRight"

    mod = get_relax_attention_module(
        q_shape,
        k_shape,
        v_shape,
        dtype="float16",
        causal_mask=causal,
        window_size=window_size,
    )

    # The query and key sequence lengths coincide, and so do the QK and V head sizes.
    q, k, v, _, ref = get_numpy_attention_ref(
        batch,
        seq_len,
        seq_len,
        num_heads,
        head_size,
        head_size,
        "none",
        "none",
        causal,
        "float16",
        window_size=window_size,
    )

    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

    # Optional cross-check against xformers (requires torch + xformers on CUDA):
    #
    # attn_bias = BlockDiagonalCausalMask.from_seqlens([64])
    # if window_size > 0:
    #     attn_bias = attn_bias.make_local_attention(window_size)
    # query = torch.from_numpy(q).to("cuda")
    # key = torch.from_numpy(k).to("cuda")
    # value = torch.from_numpy(v).to("cuda")
    # ref = xops.memory_efficient_attention_forward(
    #     query, key, value, attn_bias=attn_bias,
    # ).cpu().numpy()
    # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
| |
| |
def test_batched_var_len_sliding_window():
    """Variable-length batched attention with a sliding window of 8.

    Three sequences of 64 tokens each are packed with 128 heads of size 32. The
    window is baked into the module via ``window_size=T.IntImm("int32", 8)``, and
    the same value is passed to ``_test_batched_var_len_attention`` for the NumPy
    reference. The two must agree.
    """
    @I.ir_module
    class Module:
        # A single llvm virtual device, referenced below as "llvm:0".
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                ]
            }
        )

        @R.function
        def main(
            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
            # NOTE(review): presumably marks all four parameters as runtime inputs.
            R.func_attr({"num_input": 4})
            cls = Module
            num_tokens = T.int64()
            num_seq = T.int64()

            with R.dataflow():
                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
                # https://github.com/apache/tvm/issues/15851
                cumsum = R.call_dps_packed(
                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
                )
                # Longest sequence, placed on the llvm vdevice declared above
                # (presumably because the kernel reads it on the host; confirm).
                max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
                # Prepend 0 to the scan. If sum_scan is inclusive, this gives each
                # sequence's start offset followed by the total token count.
                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
                # Unflatten to (1, num_tokens, 128 heads, head size 32).
                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
                attn_out = R.nn.attention_var_len(
                    q,
                    k,
                    v,
                    seqstart_q,
                    max_seqlen_q,
                    causal_mask="BottomRight",
                    window_size=T.IntImm("int32", 8),
                )
                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
                R.output(out)
            return out

    seq_lens = [64, 64, 64]
    num_head = 128
    num_kv_head = 128
    head_size = 32
    # Must match the window_size baked into Module above.
    window_size = 8

    _test_batched_var_len_attention(Module, seq_lens, num_head, num_kv_head, head_size, window_size)
| |
| |
if __name__ == "__main__":
    # Allow running this test file directly; dispatches to tvm.testing.main().
    tvm.testing.main()