| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # ruff: noqa: F401 |
| |
| import tvm |
| import tvm.testing |
| from tvm import ir, tirx |
| from tvm.script import ir as I |
| from tvm.script import tirx as T |
| |
| |
def test_reuse_in_sequential_bind():
    """De-dup sequential variable bindings"""

    # The PrimFunc body is built by hand: an SSA violation is not
    # valid TIR, and may not be expressible in future versions of
    # TVMScript.
    v = tirx.Var("var", "int32")
    body = tirx.SeqStmt(
        [
            tirx.Bind(v, 16),
            tirx.Evaluate(v),
            tirx.Bind(v, 32),
            tirx.Evaluate(v),
        ]
    )
    before = tirx.PrimFunc([], body)

    @T.prim_func(private=True)
    def expected():
        var1 = T.bind(T.int32(16))
        T.evaluate(var1)
        var2 = T.bind(T.int32(32))
        T.evaluate(var2)

    # Each re-binding of `v` should become a distinct variable.
    mod = tvm.tirx.transform.ConvertSSA()(tvm.IRModule.from_expr(before))
    tvm.ir.assert_structural_equal(mod["main"], expected)
| |
| |
def test_reuse_in_nested_bind():
    """De-dup sequential bindings of the same variable.

    In the flat Bind model every Bind is a sibling within a SeqStmt,
    so a second Bind of the same variable redefines it for every
    subsequent sibling.  ConvertSSA must introduce a fresh variable at
    the second binding and rewrite all later uses to refer to it.
    """

    # Built by hand: SSA violations are not valid TIR, and may not be
    # expressible in future versions of TVMScript.
    v = tirx.Var("var", "int32")

    # NOTE: the IR builder flattens nested SeqStmt, so the input below
    # is effectively a flat SeqStmt with 5 elements.
    before = tirx.PrimFunc(
        [],
        tirx.SeqStmt(
            [
                tirx.Bind(v, 32),
                tirx.Evaluate(v),
                tirx.SeqStmt(
                    [
                        tirx.Bind(v, 16),
                        tirx.Evaluate(v),
                    ]
                ),
                tirx.Evaluate(v),
            ]
        ),
    )

    # After the second Bind (value 16), ALL subsequent siblings --
    # including the trailing Evaluate -- refer to the new variable.
    first = tirx.Var("var", "int32")
    second = tirx.Var("var", "int32")
    expected = tirx.PrimFunc(
        [],
        tirx.SeqStmt(
            [
                tirx.Bind(first, 32),
                tirx.Evaluate(first),
                tirx.Bind(second, 16),
                tirx.Evaluate(second),
                tirx.Evaluate(second),
            ]
        ),
    )

    mod = tvm.tirx.transform.ConvertSSA()(tvm.IRModule.from_expr(before))
    tvm.ir.assert_structural_equal(mod["main"], expected)
| |
| |
def test_reused_var_across_module():
    """De-duplicate Var bindings across entire module"""

    @T.prim_func(private=True)
    def func():
        var = T.bind(10)
        T.evaluate(var)

    # The same PrimFunc object (hence the same Var) backs both entries.
    before = tvm.IRModule(
        {name: func.with_attr("global_symbol", name) for name in ("func_a", "func_b")}
    )

    # NOTE(review): `before` binds via T.bind while `expected` uses
    # T.int32 -- confirm this is the intended ConvertSSA output for
    # the flat Bind model.
    @I.ir_module
    class expected:
        @T.prim_func
        def func_a():
            var = T.int32(10)
            T.evaluate(var)

        @T.prim_func
        def func_b():
            var = T.int32(10)
            T.evaluate(var)

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_reused_parameter():
    """De-duplicate Var usage in parameters

    The same `tirx.Var` instance serves as the parameter `n` of both
    functions in the input module.
    """

    @T.prim_func(private=True)
    def func(n: T.int32):
        T.evaluate(n)

    # The same PrimFunc object (hence the same parameter Var) backs
    # both entries.
    before = tvm.IRModule(
        {name: func.with_attr("global_symbol", name) for name in ("func_a", "func_b")}
    )

    @I.ir_module
    class expected:
        @T.prim_func
        def func_a(n: T.int32):
            T.evaluate(n)

        @T.prim_func
        def func_b(n: T.int32):
            T.evaluate(n)

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_reused_buffer_obj():
    """De-duplicate buffer usage across entire module"""

    @T.prim_func(private=True)
    def func(a: T.handle("float32")):
        A = T.decl_buffer(shape=1, dtype="float32", data=a)
        T.evaluate(A[0])

    # The same PrimFunc object (hence the same Buffer) backs both
    # entries.
    before = tvm.IRModule(
        {name: func.with_attr("global_symbol", name) for name in ("func_a", "func_b")}
    )

    @I.ir_module
    class expected:
        @T.prim_func
        def func_a(a: T.handle("float32")):
            A = T.decl_buffer(shape=1, dtype="float32", data=a)
            T.evaluate(A[0])

        @T.prim_func
        def func_b(a: T.handle("float32")):
            A = T.decl_buffer(shape=1, dtype="float32", data=a)
            T.evaluate(A[0])

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_reused_buffer_parameter():
    """De-duplicate buffer_map across entire module"""

    @T.prim_func(private=True)
    def func(A: T.Buffer(1, "float32")):
        T.evaluate(A[0])

    # The same PrimFunc object (hence the same buffer_map entry) backs
    # both entries.
    before = tvm.IRModule(
        {name: func.with_attr("global_symbol", name) for name in ("func_a", "func_b")}
    )

    @I.ir_module
    class expected:
        @T.prim_func
        def func_a(A: T.Buffer(1, "float32")):
            T.evaluate(A[0])

        @T.prim_func
        def func_b(A: T.Buffer(1, "float32")):
            T.evaluate(A[0])

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_no_change_if_already_ssa():
    """A module that is already SSA should be unchanged"""

    @I.ir_module
    class before:
        @T.prim_func
        def func(A: T.Buffer(1, "float32")):
            T.evaluate(A[0])

    after = tvm.tirx.transform.ConvertSSA()(before)
    # Not merely structurally equal: when no rewrite is required the
    # pass must return the very same module object.
    assert before.same_as(after)
    tvm.ir.assert_structural_equal(before, after)
| |
| |
def test_keep_duplicate_thread_idx_in_same_function():
    """Environment threads are treated as being at function scope

    The `"thread_extent"` attribute has some unique semantics.  It
    serves as the definition of the `tirx::Var` representing the
    environment thread (e.g. `threadIdx.x` in CUDA).  However,
    multiple `"thread_extent"` attributes may co-exist in the same
    PrimFunc.  For the purpose of variable scope, use of the
    `tirx::Var` is only allowed within the body of the `AttrStmt`.
    However, for the purpose of well-formed-ness, all
    `"thread_extent"` attributes must use the same IterVar instance
    (e.g. `WarpIndexFinder` in `lower_warp_memory.cc` may throw an
    error if multiple IterVar instances occur).

    If there are multiple `AttrStmt` with key `"thread_extent"` in a
    single function (represented in TVMScript as `T.launch_thread`),
    these should be treated as a definition of a single variable at
    function scope, and should not be de-duplicated.
    """

    @I.ir_module
    class before:
        @T.prim_func
        def main(A: T.Buffer([256], "float32")):
            threadIdx_x = T.env_thread("threadIdx.x")
            # Two launches of the same environment-thread var within
            # one function: this duplication is intentional and must
            # survive ConvertSSA.
            with T.launch_thread(threadIdx_x, 256):
                A[threadIdx_x] = A[threadIdx_x] + 1.0

            with T.launch_thread(threadIdx_x, 256):
                A[threadIdx_x] = A[threadIdx_x] + 2.0

    after = tvm.tirx.transform.ConvertSSA()(before)
    # Output must be identical to input: no de-duplication occurred.
    tvm.ir.assert_structural_equal(after, before)
| |
| |
def test_de_duplicate_thread_idx_across_multiple_functions():
    """Environment threads are treated as being at function scope

    See `test_keep_duplicate_thread_idx_in_same_function` for background
    information.

    If the same `tirx.Var` is used as the environment thread in
    multiple functions of an IRModule, each function must receive its
    own distinct variable: the `"thread_extent"` definitions ARE
    de-duplicated across functions, even though duplicate definitions
    within a single function are kept.

    For this test case, the `AttrStmt` for `"thread_extent"` are
    written explicitly, without using the usual `T.env_thread` and
    `T.launch_thread`, as they cannot represent the duplicate
    Var/IterVar usage across the two PrimFuncs.
    """

    threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32")

    # threadIdx_x is defined outside the module, making `before`
    # ill-formed; skip the well-formed check when parsing it.
    @I.ir_module(check_well_formed=False)
    class before:
        @T.prim_func
        def kernel_1(A: T.Buffer([256], "float32")):
            T.attr(
                T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"),
                "thread_extent",
                256,
            )
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

        @T.prim_func
        def kernel_2(A: T.Buffer([256], "float32")):
            T.attr(
                T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"),
                "thread_extent",
                256,
            )
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

    @I.ir_module
    class expected:
        @T.prim_func
        def kernel_1(A: T.Buffer([256], "float32")):
            # Each kernel now declares its own threadIdx_x variable.
            threadIdx_x = T.int32()
            T.attr(
                T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"),
                "thread_extent",
                256,
            )
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

        @T.prim_func
        def kernel_2(A: T.Buffer([256], "float32")):
            threadIdx_x = T.int32()
            T.attr(
                T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"),
                "thread_extent",
                256,
            )
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_de_duplicate_thread_idx_iter_var_across_multiple_functions():
    """Environment threads are treated as being at function scope

    Like `test_de_duplicate_thread_idx_across_multiple_functions`, except the
    `IterVar` for the environment thread is duplicated across multiple
    PrimFuncs, not just the `tirx.Var` inside the `IterVar`.
    """

    threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32")
    iter_var = tvm.tirx.IterVar(
        tvm.ir.Range(0, 256), threadIdx_x, tvm.tirx.IterVar.ThreadIndex, "threadIdx.x"
    )

    # The well-formed check would complain about multiple definitions
    # of threadIdx_x, so it is skipped for the input module.
    @I.ir_module(check_well_formed=False)
    class before:
        @T.prim_func
        def kernel_1(A: T.Buffer([256], "float32")):
            T.attr(iter_var, "thread_extent", 256)
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

        @T.prim_func
        def kernel_2(A: T.Buffer([256], "float32")):
            T.attr(iter_var, "thread_extent", 256)
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

    @I.ir_module(check_well_formed=False)
    class expected:
        @T.prim_func
        def kernel_1(A: T.Buffer([256], "float32")):
            # Each kernel now declares its own threadIdx_x variable.
            threadIdx_x = T.int32()
            T.attr(
                T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"),
                "thread_extent",
                256,
            )
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

        @T.prim_func
        def kernel_2(A: T.Buffer([256], "float32")):
            threadIdx_x = T.int32()
            T.attr(
                T.iter_var(threadIdx_x, T.Range(0, 256), "ThreadIndex", "threadIdx.x"),
                "thread_extent",
                256,
            )
            A[threadIdx_x] = A[threadIdx_x] + T.float32(1)

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_thread_idx_reused_within_and_across_functions():
    """Environment threads are treated as being at function scope

    A combination of
    test_de_duplicate_thread_idx_iter_var_across_multiple_functions and
    test_keep_duplicate_thread_idx_in_same_function.  The re-use within a
    function should be maintained, while re-use across functions is
    de-duplicated.
    """

    threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32")
    iter_var = tvm.tirx.IterVar(
        tvm.ir.Range(0, 256), threadIdx_x, tvm.tirx.IterVar.ThreadIndex, "threadIdx.x"
    )

    # The well-formed check would complain about multiple definitions
    # of threadIdx_x, so it is skipped for the input module.
    @I.ir_module(check_well_formed=False)
    class before:
        @T.prim_func
        def kernel_1(A: T.Buffer([256], "float32")):
            with T.attr(iter_var, "thread_extent", 256):
                A[threadIdx_x] = A[threadIdx_x] + 1.0
            with T.attr(iter_var, "thread_extent", 256):
                A[threadIdx_x] = A[threadIdx_x] + 2.0

        @T.prim_func
        def kernel_2(A: T.Buffer([256], "float32")):
            with T.attr(iter_var, "thread_extent", 256):
                A[threadIdx_x] = A[threadIdx_x] + 1.0
            with T.attr(iter_var, "thread_extent", 256):
                A[threadIdx_x] = A[threadIdx_x] + 2.0

    @I.ir_module
    class expected:
        @T.prim_func
        def kernel_1(A: T.Buffer([256], "float32")):
            # Each kernel gets its own env-thread var, but the two
            # launches within a kernel still share it.
            threadIdx_x = T.env_thread("threadIdx.x")
            with T.launch_thread(threadIdx_x, 256):
                A[threadIdx_x] = A[threadIdx_x] + 1.0
            with T.launch_thread(threadIdx_x, 256):
                A[threadIdx_x] = A[threadIdx_x] + 2.0

        @T.prim_func
        def kernel_2(A: T.Buffer([256], "float32")):
            threadIdx_x = T.env_thread("threadIdx.x")
            with T.launch_thread(threadIdx_x, 256):
                A[threadIdx_x] = A[threadIdx_x] + 1.0
            with T.launch_thread(threadIdx_x, 256):
                A[threadIdx_x] = A[threadIdx_x] + 2.0

    after = tvm.tirx.transform.ConvertSSA()(before)
    tvm.ir.assert_structural_equal(after, expected)
| |
| |
def test_track_forward_declarations_in_attr_stmt():
    """T.attr statements may refer to an about-to-be-defined tirx.Var"""

    # The PrimFunc (already SSA) is constructed directly rather than
    # through TVMScript: this test needs a `tirx.AttrStmt` that
    # references a variable which is only defined by the `tirx.For`
    # that follows it, and TVMScript only exposes a loop iterator
    # inside the body of the loop.
    i0_outer_outer = tirx.Var("i0_outer_outer", "int32")
    i0_outer_inner = tirx.Var("i0_outer_inner", "int32")
    i0_inner = tirx.Var("i0_inner", "int32")

    A = tirx.decl_buffer(1024, "float32", "A")
    B = tirx.decl_buffer(1024, "float32", "B")

    index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner

    def pragma(var, key, inner):
        """Wrap `inner` in an AttrStmt pragma keyed on `var`."""
        return tirx.AttrStmt(T.iter_var(var, None, "DataPar", ""), key, 1, inner)

    # Build the body inside-out: store, guard, loops, pragmas.
    body = tirx.BufferStore(B, tirx.BufferLoad(A, [index]), [index])
    body = tirx.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, body, None)
    body = tirx.For(i0_inner, 0, 4, tirx.ForKind.VECTORIZED, body)
    body = tirx.For(i0_outer_inner, 0, 13, tirx.ForKind.PARALLEL, body)
    body = pragma(i0_outer_inner, "pragma_parallal_barrier_when_finish", body)
    body = pragma(i0_outer_inner, "pragma_parallal_stride_pattern", body)
    body = tirx.For(i0_outer_outer, 0, 20, tirx.ForKind.SERIAL, body)
    body = pragma(i0_outer_outer, "pragma_parallal_launch_point", body)

    A_handle = tirx.Var("A_handle", "handle")
    B_handle = tirx.Var("B_handle", "handle")

    before = tirx.PrimFunc(
        [A_handle, B_handle],
        body,
        buffer_map={A_handle: A, B_handle: B},
    )

    after = tvm.tirx.transform.ConvertSSA()(tvm.IRModule.from_expr(before))
    # The input is already SSA, so it must pass through unchanged.
    tvm.ir.assert_structural_equal(after["main"], before)
| |
| |
def test_shared_shape_var_in_buffer_map_and_alloc_buffer():
    """Shape var shared across buffer_map entries and AllocBuffer should not be renamed.

    When the same SizeVar (e.g., `n`) appears in multiple buffer_map
    entries (A and B both have shape [n]), ConvertSSA should not treat
    the second occurrence as a redefinition.  All uses of `n` in the
    function body (including AllocBuffer shapes) must remain the same
    Var object so that MakePackedAPI can bind it from the DLTensor shape.
    """
    n = tirx.SizeVar("n", "int32")
    handle_a = tirx.Var("A_handle", "handle")
    handle_b = tirx.Var("B_handle", "handle")
    buf_a = tirx.decl_buffer((n,), "float32", "A")
    buf_b = tirx.decl_buffer((n,), "float32", "B")

    # The body allocates a buffer whose shape reuses the same `n`
    # (flat AllocBuffer, no nested body).
    buf_c = tirx.decl_buffer((n,), "float32", "C")
    before = tirx.PrimFunc(
        [handle_a, handle_b],
        tirx.SeqStmt([tirx.AllocBuffer(buf_c), tirx.Evaluate(1)]),
        buffer_map={handle_a: buf_a, handle_b: buf_b},
    )

    after = tvm.tirx.transform.ConvertSSA()(tvm.IRModule.from_expr(before))
    # The function is already SSA -- ConvertSSA should not change it.
    tvm.ir.assert_structural_equal(after["main"], before)
| |
| |
# Allow running this test file directly as a script.
if __name__ == "__main__":
    tvm.testing.main()