| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import re |
| |
| import numpy as np |
| |
| import tvm |
| import tvm.testing |
| from tvm.script import tirx as Tx |
| |
| |
def generate_random_data(shape, dtype, seed=0):
    """Return a reproducible standard-normal sample converted to ``dtype``.

    Parameters
    ----------
    shape : tuple of int
        Dimensions of the array to draw. An empty tuple yields a 0-d array.
    dtype : str or numpy dtype
        Element type the drawn sample is converted to via ``astype``.
    seed : int, optional
        Seed of the generator. The default (0) yields the same values that
        seeding NumPy's global generator with 0 and calling
        ``np.random.randn`` would produce.

    Returns
    -------
    numpy.ndarray
        Array with the requested shape and dtype.
    """
    # A private legacy generator produces the same stream as the global one
    # for a given seed, but leaves NumPy's global random state untouched.
    rng = np.random.RandomState(seed)
    # ``randn()`` without dimensions returns a plain float, so wrap the result
    # to guarantee ``astype`` is available for every shape.
    return np.asarray(rng.randn(*shape)).astype(dtype)
| |
| |
def create_tvm_arrays(data_np, device):
    """Create one TVM runtime tensor on ``device`` for every array in ``data_np``.

    Parameters
    ----------
    data_np : iterable
        Host arrays handed to ``tvm.runtime.tensor`` one by one.
    device
        Device passed through as the ``device`` keyword of each tensor.

    Returns
    -------
    list
        The created tensors, in the same order as ``data_np``.
    """
    tensors = []
    for host_array in data_np:
        tensors.append(tvm.runtime.tensor(host_array, device=device))
    return tensors
| |
| |
def build_and_run_tvm_func(sch, target, *args):
    """Compile the module held by ``sch`` for ``target`` and call it once.

    Parameters
    ----------
    sch
        Schedule whose ``mod`` attribute is compiled.
    target
        Compilation target forwarded to ``tvm.compile``.
    *args
        Arguments the compiled function is invoked with.

    Returns
    -------
    tuple
        ``(compiled_function, args[-1])``; the last argument is handed back
        after the call has run.
    """
    compiled = tvm.compile(sch.mod, target=target)
    compiled(*args)
    last_argument = args[-1]
    return compiled, last_argument
| |
| |
def from_source(code):
    """Parse the script text ``code`` with ``tvm.script.from_source`` using ``s_tir=True``.

    Returns whatever ``tvm.script.from_source`` produces for that text.
    """
    parsed = tvm.script.from_source(code, s_tir=True)
    return parsed
| |
| |
def verify_result(C_tvm, C_np):
    """Assert that ``C_tvm`` matches the reference ``C_np``.

    ``C_tvm`` is converted with its ``numpy()`` method and compared through
    ``tvm.testing.assert_allclose`` using a relative tolerance of 1e-5.
    """
    actual = C_tvm.numpy()
    tvm.testing.assert_allclose(actual, C_np, rtol=1e-5)
| |
| |
def verify_tir_code(code):
    """Assert that ``code`` survives a parse/print round trip unchanged.

    The text is parsed with ``from_source`` and printed again with
    ``script()``; the printed text must equal the input exactly.
    """
    reprinted = from_source(code).script()
    assert reprinted == code
| |
| |
def verify_cuda_code_array(func, dim_num, dtype, *dims):
    """Check the first ``print_buffer`` section of the generated source for an array dump.

    Parameters
    ----------
    func
        Compiled object; the source is read from
        ``func.mod.imports[0].inspect_source()``.
    dim_num : int
        Number of ``for`` loop headers expected inside the section.
    dtype : str
        One of ``"float32"``, ``"float16"``, ``"int32"``, ``"uint32"``; selects
        the expected ``printf`` format specifier.
    *dims : int
        Expected upper bounds of the loops, in order of appearance.

    Raises
    ------
    AssertionError
        If the section is missing, the loop count or bounds differ, the dtype
        is unsupported, or no matching element ``printf`` is found.
    """
    source = func.mod.imports[0].inspect_source()

    section_match = re.search(
        r"// print_buffer starts(.*?)// print_buffer ends", source, re.DOTALL
    )
    if section_match is None:
        raise AssertionError("print_buffer section not found in generated code")
    section = section_match.group(1).strip()

    # Each header has the form ``for (int iN = 0; iN < LIMIT; ++iN)``; the
    # back-reference forces the same index variable in all three clauses.
    headers = re.findall(r"for \(int i(\d+) = 0; i\1 < (\d+); \+\+i\1\)", section)
    if len(headers) != dim_num:
        raise AssertionError(f"Expected {dim_num} nested loops, but found {len(headers)}")

    loop_limits = [int(header[1]) for header in headers]
    if loop_limits != list(dims):
        raise AssertionError(f"Expected loop limits {dims}, but found {loop_limits}")

    printf_specifiers = {"float32": "%f", "float16": "%f", "int32": "%d", "uint32": "%u"}
    if dtype not in printf_specifiers:
        raise AssertionError(f"Unsupported dtype {dtype}")
    specifier = printf_specifiers[dtype]

    # An indexed buffer access such as ``C[...]``; float16 elements are
    # expected to be wrapped in ``static_cast<float>(...)`` before printing.
    element = r"\w+\[.*\]"
    if dtype == "float16":
        element = r"static_cast<float>\(" + element + r"\)"

    printf_pattern = re.compile(
        r'printf\s*\(\s*"' + re.escape(specifier) + r'"\s*,\s*' + element + r"\s*\)"
    )
    if printf_pattern.search(section) is None:
        raise AssertionError(
            f'Expected element printf statement with format "{specifier}" and a buffer access, but not found'  # noqa: E501
        )
| |
| |
def verify_cuda_code_scalar(func, dtype, expected_value_or_varname):
    """Check that some ``print_buffer`` section prints the given scalar.

    Parameters
    ----------
    func
        Compiled object; the source is read from
        ``func.mod.imports[0].inspect_source()``.
    dtype : str
        One of ``"float32"``, ``"float16"``, ``"int32"``, ``"uint32"``; selects
        the expected ``printf`` format specifier.
    expected_value_or_varname : int, float or str
        A number is matched as a literal (formatted as a float with an
        optional ``f`` suffix when ``dtype`` contains ``"float"``, otherwise
        as an integer); a string is matched verbatim as a variable name.

    Raises
    ------
    AssertionError
        If no section exists, the dtype is unsupported, or no section
        contains a matching ``printf``.
    TypeError
        If ``expected_value_or_varname`` is neither a number nor a string.
    """
    source = func.mod.imports[0].inspect_source()
    sections = re.findall(r"// print_buffer starts(.*?)// print_buffer ends", source, re.DOTALL)
    if not sections:
        raise AssertionError("No print_buffer sections found in generated code")

    printf_specifiers = {"float32": "%f", "float16": "%f", "int32": "%d", "uint32": "%u"}
    if dtype not in printf_specifiers:
        raise AssertionError(f"Unsupported dtype for scalar verification: {dtype}")
    specifier = printf_specifiers[dtype]

    if isinstance(expected_value_or_varname, (int, float)):
        if "float" in dtype:
            operand = re.escape(str(float(expected_value_or_varname))) + "f?"
        else:
            operand = re.escape(str(int(expected_value_or_varname)))
    elif isinstance(expected_value_or_varname, str):
        operand = re.escape(expected_value_or_varname)
    else:
        raise TypeError(
            "expected_value_or_varname must be a number (for literals) or a string (for variables)"
        )

    # float16 operands are expected inside ``static_cast<float>(...)``.
    if dtype == "float16":
        operand = r"static_cast<float>\(\s*" + operand + r"\s*\)"

    # The format string may carry extra text around the specifier.
    printf_pattern = re.compile(
        r'printf\s*\(\s*".*?' + re.escape(specifier) + r'.*?",\s*' + operand + r"\s*\)"
    )
    if any(printf_pattern.search(section) for section in sections):
        return

    raise AssertionError(
        f'Could not find a scalar printf with format "{specifier}" and value/variable '
        f'"{expected_value_or_varname}" in any print_buffer block.'
    )
| |
| |
def verify_cuda_code_string(func, expected_var_name, expected_string_literal):
    """Check that some ``print_buffer`` section prints a string via ``%s``.

    A section is accepted when it contains a ``printf`` whose format holds
    ``%s`` and whose argument is a ``(char*)`` cast of either the variable
    ``expected_var_name`` or the quoted literal ``expected_string_literal``.

    Parameters
    ----------
    func
        Compiled object; the source is read from
        ``func.mod.imports[0].inspect_source()``.
    expected_var_name : str
        Variable name accepted as the printed operand.
    expected_string_literal : str
        String literal (without quotes) accepted as the printed operand.

    Raises
    ------
    AssertionError
        If no section exists or none contains a matching ``printf``.
    """
    source = func.mod.imports[0].inspect_source()
    sections = re.findall(r"// print_buffer starts(.*?)// print_buffer ends", source, re.DOTALL)
    if not sections:
        raise AssertionError("No print_buffer sections found in generated code")

    # Shared prefix: ``printf("...%s...", (char*)``
    cast_prefix = r'printf\s*\(\s*".*?%s.*?",\s*\(char\*\)'
    by_variable = re.compile(cast_prefix + re.escape(expected_var_name) + r"\s*\)")
    by_literal = re.compile(
        cast_prefix + r'\s*"' + re.escape(expected_string_literal) + r'"\s*\)'
    )

    if any(by_variable.search(section) or by_literal.search(section) for section in sections):
        return

    raise AssertionError(
        f'Could not find a string printf using variable "{expected_var_name}" or '
        f'string literal "{expected_string_literal}" in any print_buffer block.'
    )
| |
| |
def test_print():
    """Exercise ``Tx.print_buffer`` in CUDA kernels for arrays, a scalar and a string.

    Every case builds an element-wise add kernel that also calls
    ``Tx.print_buffer``, binds its loops to CUDA block/thread indices, runs it,
    compares the sum against NumPy, checks that the kernel's script text
    round-trips through the parser, and finally inspects the generated CUDA
    source for the expected ``printf`` code.
    """
    device = tvm.cuda()
    cuda_target = tvm.target.Target("cuda")

    def prepare_operands(shape, dtype):
        """Return ``(device_inputs, reference_sum)`` for two random operands."""
        lhs_np = generate_random_data(shape, dtype)
        rhs_np = generate_random_data(shape, dtype)
        reference = lhs_np + rhs_np
        return create_tvm_arrays([lhs_np, rhs_np], device), reference

    def launch_and_check(kernel, tile_sizes, operands, reference, dtype):
        """Schedule ``kernel``, run it, and validate its output and script text.

        Returns the compiled function so the caller can inspect its source.
        """
        sch = tvm.s_tir.Schedule(kernel)
        loops = sch.get_loops(sch.get_sblock("C"))

        # Split every loop first; then bind all outer parts to block indices
        # before binding the inner parts to thread indices.
        pieces = [
            sch.split(loop, factors=[None, tile]) for loop, tile in zip(loops, tile_sizes)
        ]
        for axis, (outer, _inner) in zip("xyz", pieces):
            sch.bind(outer, f"blockIdx.{axis}")
        for axis, (_outer, inner) in zip("xyz", pieces):
            sch.bind(inner, f"threadIdx.{axis}")

        result = tvm.runtime.tensor(np.zeros(reference.shape, dtype=dtype), device=device)
        compiled, result = build_and_run_tvm_func(sch, cuda_target, *operands, result)
        verify_result(result, reference)
        verify_tir_code(kernel.script())
        return compiled

    def check_vector_add_1d(dtype, dtype_str):
        """1-D add kernel that prints the whole output buffer."""
        M = 6
        dim_num = 1
        operands, reference = prepare_operands((M,), dtype)

        @Tx.prim_func(s_tir=True)
        def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None:
            A = Tx.match_buffer(A_ptr, (M,), dtype_str)
            B = Tx.match_buffer(B_ptr, (M,), dtype_str)
            C = Tx.match_buffer(C_ptr, (M,), dtype_str)

            for i in Tx.grid(M):
                with Tx.sblock("C"):
                    vi = Tx.axis.spatial(M, i)
                    C[vi] = A[vi] + B[vi]
                    Tx.print_buffer(C.data, dtype_str, False, False, dim_num, (M,))

        compiled = launch_and_check(add_func, (6,), operands, reference, dtype)
        verify_cuda_code_array(compiled, dim_num, dtype_str, M)

    def check_vector_add_2d(dtype, dtype_str):
        """2-D add kernel that prints the whole output buffer."""
        M, N = 6, 6
        dim_num = 2
        operands, reference = prepare_operands((M, N), dtype)

        @Tx.prim_func(s_tir=True)
        def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None:
            A = Tx.match_buffer(A_ptr, (M, N), dtype_str)
            B = Tx.match_buffer(B_ptr, (M, N), dtype_str)
            C = Tx.match_buffer(C_ptr, (M, N), dtype_str)

            for i, j in Tx.grid(M, N):
                with Tx.sblock("C"):
                    vi = Tx.axis.spatial(M, i)
                    vj = Tx.axis.spatial(N, j)
                    C[vi, vj] = A[vi, vj] + B[vi, vj]
                    Tx.print_buffer(C.data, C.dtype, False, False, dim_num, (M, N))

        compiled = launch_and_check(add_func, (6, 6), operands, reference, dtype)
        verify_cuda_code_array(compiled, dim_num, dtype_str, M, N)

    def check_vector_add_3d(dtype, dtype_str):
        """3-D add kernel that prints the whole output buffer."""
        M, N, K = 6, 6, 6
        dim_num = 3
        operands, reference = prepare_operands((M, N, K), dtype)

        @Tx.prim_func(s_tir=True)
        def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None:
            A = Tx.match_buffer(A_ptr, (M, N, K), dtype_str)
            B = Tx.match_buffer(B_ptr, (M, N, K), dtype_str)
            C = Tx.match_buffer(C_ptr, (M, N, K), dtype_str)

            for i, j, k in Tx.grid(M, N, K):
                with Tx.sblock("C"):
                    vi = Tx.axis.spatial(M, i)
                    vj = Tx.axis.spatial(N, j)
                    vk = Tx.axis.spatial(K, k)
                    C[vi, vj, vk] = A[vi, vj, vk] + B[vi, vj, vk]
                    Tx.print_buffer(C.data, C.dtype, False, False, dim_num, (M, N, K))

        compiled = launch_and_check(add_func, (6, 6, 6), operands, reference, dtype)
        verify_cuda_code_array(compiled, dim_num, dtype_str, M, N, K)

    def check_const_scalar(dtype, dtype_str):
        """1-D add kernel that prints the constant scalar 10."""
        M = 6
        dim_num = 1
        operands, reference = prepare_operands((M,), dtype)

        @Tx.prim_func(s_tir=True)
        def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None:
            A = Tx.match_buffer(A_ptr, (M,), dtype_str)
            B = Tx.match_buffer(B_ptr, (M,), dtype_str)
            C = Tx.match_buffer(C_ptr, (M,), dtype_str)
            Ten: Tx.let = Tx.IntImm(dtype_str, 10)

            for i in Tx.grid(M):
                with Tx.sblock("C"):
                    vi = Tx.axis.spatial(M, i)
                    C[vi] = A[vi] + B[vi]
                    Tx.print_buffer(Ten, "int32", False, True, dim_num, ())

        compiled = launch_and_check(add_func, (6,), operands, reference, dtype)
        verify_cuda_code_scalar(compiled, dtype_str, 10)

    def check_string(dtype, dtype_str, test_string):
        """1-D add kernel that prints the string ``test_string``."""
        M = 6
        dim_num = 1
        operands, reference = prepare_operands((M,), dtype)

        @Tx.prim_func(s_tir=True)
        def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None:
            A = Tx.match_buffer(A_ptr, (M,), dtype_str)
            B = Tx.match_buffer(B_ptr, (M,), dtype_str)
            C = Tx.match_buffer(C_ptr, (M,), dtype_str)
            string_var = Tx.StringImm(test_string)

            for i in Tx.grid(M):
                with Tx.sblock("C"):
                    vi = Tx.axis.spatial(M, i)
                    C[vi] = A[vi] + B[vi]
                    Tx.print_buffer(string_var, "int8", True, False, dim_num, ())

        compiled = launch_and_check(add_func, (6,), operands, reference, dtype)
        verify_cuda_code_string(compiled, "string_var", test_string)

    check_vector_add_1d(np.float32, "float32")
    check_vector_add_2d(np.int32, "int32")
    check_vector_add_2d(np.float16, "float16")
    check_vector_add_3d(np.uint32, "uint32")
    check_string(np.float32, "float32", "hello tirx!")
    check_const_scalar(np.int32, "int32")
| |
| |
# Running this file as a script executes the test directly.
if __name__ == "__main__":
    test_print()