| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from typing import List, Set, Union |
| |
| import tvm |
| import tvm.testing |
| from tvm import relax as rx |
| from tvm import tir |
| from tvm.relax.analysis import ( |
| all_global_vars, |
| all_vars, |
| bound_vars, |
| free_vars, |
| has_reshape_pattern, |
| name_to_binding, |
| remove_all_unused, |
| udchain, |
| ) |
| from tvm.script import ir as I |
| from tvm.script import relax as R |
| from tvm.script import tir as T |
| |
| |
| def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]: |
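    """Collect the name hints of the given Relax variables into a set."""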
    return {v.name_hint for v in vars}
| |
| |
| def test_use_def(): |
| m = tir.Var("m", "int64") |
| n = tir.Var("n", "int64") |
| x = rx.Var("x", R.Tensor([m, n], "float16")) |
| y = rx.Var("y", R.Tensor([n], "float16")) |
| ib = rx.BlockBuilder() |
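    # Build func(x, y): a dataflow block computing gv0 = (x + y) * y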
| with ib.function("func", [x, y]): |
| with ib.dataflow(): |
| lv0 = ib.emit(rx.op.add(x, y)) |
| lv1 = ib.emit(rx.op.multiply(lv0, y)) |
| gv0 = ib.emit_output(lv1) |
| ib.emit_func_output(gv0) |
| dfb = ib.get()["func"].body.blocks[0] |
| udc = udchain(dfb) |
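    # udchain maps each variable to the set of variables whose bound expressions use it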
| assert set(udc[x]) == {lv0} |
| assert set(udc[y]) == {lv0, lv1} |
| assert set(udc[lv0]) == {lv1} |
| assert set(udc[lv1]) == {gv0} |
| assert set(udc[gv0]) == set() |
| |
| |
| def test_chained_remove_all_unused(): |
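    """Unused bindings are removed transitively

    `unused1` consumes `unused0`; once `unused1` is removed as unused,
    `unused0` becomes unused and is removed as well.
    """
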
| @tvm.script.ir_module |
| class IdentityUnused: |
| @R.function |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) |
| unused1 = R.call_dps_packed( |
| "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") |
| ) |
| R.output(lv0) |
| return lv0 |
| |
| optimized = remove_all_unused(IdentityUnused["main"]) |
| |
| @tvm.script.ir_module |
| class GroundTruth: |
| @R.function |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| R.output(lv0) |
| return lv0 |
| |
| tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) |
| |
| |
| def test_binding_block_remove_all_unused(): |
| """Remove unused dataflow bindings |
| |
| Removal of unused bindings may not remove side effects. Since |
| bindings within a dataflow block are guaranteed not to have side |
| effects, they may be removed if unused. |
| """ |
| |
| @tvm.script.ir_module |
| class IdentityUnused: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) |
| unused1 = R.call_dps_packed( |
| "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") |
| ) |
| R.output(lv0) |
| z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return z |
| |
| optimized = remove_all_unused(IdentityUnused["main"]) |
| |
| @tvm.script.ir_module |
| class GroundTruth: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| R.output(lv0) |
| z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return z |
| |
| tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) |
| |
| |
| def test_binding_block_remove_unused_pure_without_dataflow(): |
| """Remove unused dataflow bindings |
| |
| Removal of unused bindings may not remove side effects. Unused |
| bindings whose value is a pure operation |
| (e.g. `R.call_dps_packed`) may be removed, even if outside of a |
| dataflow block. |
| """ |
| |
| @R.function(private=True) |
| def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) |
| unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")) |
| return x |
| |
| @R.function(private=True) |
| def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| return x |
| |
| after = remove_all_unused(before) |
| tvm.ir.assert_structural_equal(expected, after) |
| |
| |
| def test_binding_block_keep_impure_without_dataflow(): |
| """Remove unused dataflow bindings |
| |
| Removal of unused bindings may not remove side effects. Unused |
| bindings whose value is an impure operation (e.g. `R.call_packed`) |
| may not be removed, as outside of a dataflow block they may |
| contain side effects. |
| """ |
| |
| @R.function(private=True, pure=False) |
| def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return y |
| |
| expected = before |
| |
| after = remove_all_unused(before) |
| tvm.ir.assert_structural_equal(expected, after) |
| |
| |
| def test_binding_block_keep_pure_func_used_only_for_impure(): |
| """Keep bindings that are used for impure functions |
| |
| Removal of unused bindings may not result in use of undefined |
| variables. Unused bindings whose value is an impure operation |
| (e.g. `R.call_packed`) may not be removed, nor may any of their |
| inputs. |
| |
| This is a regression test to catch an earlier failure mode, in |
| which tracking of unused variables only back-propagated from the |
| return value of functions, and did not consider variables that |
| were required to execute impure function calls. In that failure |
| mode, the binding of `y` would be removed as unused, even though |
| it was required to evaluate the packed function. |
| """ |
| |
| @R.function(private=True, pure=False) |
| def before(x: R.Tensor((32, 32), "int32")): |
| y = x * R.const(2) |
| z = R.call_packed( |
| "function_maybe_with_side_effects", y, sinfo_args=(R.Tensor((32, 32), "int32")) |
| ) |
| return R.tuple() |
| |
| expected = before |
| |
| after = remove_all_unused(before) |
| tvm.ir.assert_structural_equal(expected, after) |
| |
| |
| def test_binding_block_remove_all_unused_func_without_dataflow(): |
| @tvm.script.ir_module |
| class IdentityUnused: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
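            # The nested function binding below is pure and unused, so it can
            # be removed even outside a dataflow block.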
| |
| @R.function |
| def internal_unused_func(A: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| return A |
| |
| z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return z |
| |
| optimized = remove_all_unused(IdentityUnused["main"]) |
| |
| @tvm.script.ir_module |
| class GroundTruth: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return z |
| |
| tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) |
| |
| |
| def test_binding_block_fake_unused_remove_all_unused(): |
| @tvm.script.ir_module |
| class IdentityUnused: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| R.output(lv0) |
| z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return lv0 |
| |
| optimized = remove_all_unused(IdentityUnused["main"]) |
| |
| @tvm.script.ir_module |
| class GroundTruth: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| R.output(lv0) |
            # This call may have side effects, so it cannot be removed.
| z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return lv0 |
| |
| tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) |
| |
| |
| def test_edge_binding_block_fake_unused_remove_all_unused(): |
| @tvm.script.ir_module |
| class IdentityUnused: |
| @R.function(pure=False) |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): |
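            # `z` is unused, but the packed call may have side effects,
            # so its binding must be kept.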
| z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) |
| return x |
| |
| optimized = remove_all_unused(IdentityUnused["main"]) |
| tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) |
| |
| |
| def test_edge_binding_block_fake_unused_remove_all_unused2(): |
| @tvm.script.ir_module |
| class IdentityUnused: |
| @R.function |
| def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(dtype="int32", ndim=3): |
| m = T.int64() |
| n = T.int64() |
| k = T.int64() |
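            # `lv` feeds the match_cast that binds the symbolic vars m, n, k
            # used by `gv`, so none of these bindings may be removed.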
| with R.dataflow(): |
| lv: R.Shape(ndim=3) = R.call_pure_packed( |
| "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=3),) |
| ) |
| lv1: R.Shape([m, n, k]) = R.match_cast(lv, R.Shape([m, n, k])) |
| gv: R.Tensor((m, n, k), dtype="int32") = R.full( |
| R.shape([m, n, k]), R.const(1, "int32"), dtype="int32" |
| ) |
| R.output(gv) |
| return gv |
| |
| optimized = remove_all_unused(IdentityUnused["main"]) |
| tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) |
| |
| |
| def test_remove_all_unused_from_dataflow_block(): |
| """Like test_chained_remove_all_unused, but on a SeqExpr""" |
| |
| @R.function |
| def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) |
| unused1 = R.call_dps_packed( |
| "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") |
| ) |
| R.output(lv0) |
| return lv0 |
| |
| @R.function |
| def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| R.output(lv0) |
| return lv0 |
| |
| after = remove_all_unused(before.body) |
| tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) |
| |
| |
| def test_remove_all_unused_from_binding_block(): |
| """Like test_chained_remove_all_unused, but on a SeqExpr""" |
| |
| @R.function |
| def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) |
| unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")) |
| return lv0 |
| |
| @R.function |
| def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| return lv0 |
| |
| after = remove_all_unused(before.body) |
| tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) |
| |
| |
| def test_retain_impure_calls_unused_in_binding_block(): |
| """An impure call may have side effects, and must be kept""" |
| |
| @R.function(pure=False) |
| def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) |
| unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 32), dtype="float32")) |
| return lv0 |
| |
| @R.function(pure=False) |
| def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| lv0 = x |
| unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) |
| return lv0 |
| |
| after = remove_all_unused(before.body) |
| tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) |
| |
| |
| def test_retain_calls_to_impure_builtin_ops(): |
| @I.ir_module |
| class Module: |
| @T.prim_func(private=True) |
| def my_tir(A: T.handle, B: T.handle, n: T.int64): |
| T.evaluate(0) |
| |
| @R.function(pure=False) |
| def main(x: R.Tensor(("n",), "float32")): |
| cls = Module |
| n = T.int64() |
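            # The allocations feed the impure calls below, so they are kept too.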
| storage = R.memory.alloc_storage((n * 4,), 0, "global", "float32") |
| alloc = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), "float32") |
| # "call_tir_dyn" is impure which shouldn't be removed. |
| R.vm.call_tir_dyn(cls.my_tir, (x, alloc, R.shape([n]))) |
| # "kill_tensor"/"kill_storage" are impure which shouldn't be removed. |
| R.memory.kill_tensor(alloc) |
| R.memory.kill_storage(storage) |
| return x |
| |
| after = remove_all_unused(Module["main"]) |
| tvm.ir.assert_structural_equal(after, Module["main"], map_free_vars=True) |
| |
| |
| def test_name_to_binding_var_shadowing(): |
| @R.function |
| def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: |
| with R.dataflow(): |
| lv0 = x |
| lv1 = lv0 |
| R.output(lv1) |
| |
| with R.dataflow(): |
| lv0 = lv1 # shadowing |
| lv2 = lv0 |
| R.output(lv2) |
| return lv2 |
| |
| n2binding = name_to_binding(main) |
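    # "lv0" is bound once in each dataflow block, so its name maps to two bindings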
| |
| assert "lv0" in n2binding |
| assert "lv1" in n2binding |
| assert "lv2" in n2binding |
| |
| assert len(n2binding["lv0"]) == 2 |
| |
| |
| @tvm.script.ir_module |
| class VarExample: |
| @R.function |
| def func(a: R.Tensor) -> R.Tensor: |
| # normalized into assigning R.add(a, a) to a var and returning it |
| return R.add(a, a) |
| |
| @R.function |
| def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: |
| cls = VarExample |
| z = R.add(x, y) |
        # the match_cast result is bound to the throwaway var "_" and never used
        _ = R.match_cast(x, R.Tensor((5, 5)))
| with R.dataflow(): |
| q = R.add(z, z) |
| p = cls.func(q) |
| r = R.match_cast(p, R.Tensor((5, 5))) |
| s = r |
| R.output(s) |
| return s |
| |
| |
| def test_all_vars(): |
| vars = all_vars(VarExample["func"]) |
| assert len(vars) == 2 |
| assert vars[0].name_hint == "a" |
| # the body of the seq expr in the func body is a var |
| assert vars[1] == VarExample["func"].body.body |
| |
| var_names = var_name_set(all_vars(VarExample["main"])) |
| assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"} |
| |
| |
| def test_all_vars_from_expr_using_dataflow(): |
| """all_vars() should return all Var, including DataflowVar""" |
| func = VarExample["main"] |
| cls_func_q = func.body.blocks[1].bindings[1].value |
| |
| var_names = var_name_set(all_vars(cls_func_q)) |
| assert var_names == {"q"} |
| |
| |
| def test_bound_vars(): |
| vars = bound_vars(VarExample["func"]) |
| assert len(vars) == 2 |
| assert vars[0].name_hint == "a" |
| # the body of the seq expr in the func body is a bound var |
| assert vars[1] == VarExample["func"].body.body |
| |
| # all the vars are bound |
| var_names = var_name_set(bound_vars(VarExample["main"])) |
| assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"} |
| |
| # if we consider only the body, then the function arguments are not bound |
| body_names = var_name_set(bound_vars(VarExample["main"].body)) |
| assert body_names == {"_", "z", "p", "q", "r", "s"} |
| |
    # the only binding is in the (normalized) body
| simple_body_vars = bound_vars(VarExample["func"].body) |
| assert len(simple_body_vars) == 1 |
| assert simple_body_vars[0] == VarExample["func"].body.body |
| |
| |
| def test_free_vars(): |
| # all the vars are bound |
| assert len(free_vars(VarExample["func"])) == 0 |
| assert len(free_vars(VarExample["main"])) == 0 |
| |
| # the arguments are free if we look only at the bodies |
| func_free = var_name_set(free_vars(VarExample["func"].body)) |
| main_free = var_name_set(free_vars(VarExample["main"].body)) |
| assert len(func_free) == 1 |
| assert len(main_free) == 2 |
| assert "a" in func_free |
| assert main_free == {"x", "y"} |
| |
| # function that captures vars |
| x = rx.Var("x", R.Tensor(ndim=-1)) |
| y = rx.Var("y", R.Tensor(ndim=-1)) |
| z = rx.Var("z", R.Tensor(ndim=-1)) |
| inner = rx.Function( |
| [z], |
| rx.op.add(x, rx.op.add(y, z)), |
| ret_struct_info=R.Tensor(ndim=-1), |
| ) |
| outer = rx.Function( |
| [x, y], |
| rx.Call(inner, [y]), |
| ret_struct_info=R.Tensor(ndim=-1), |
| ) |
| assert len(free_vars(outer)) == 0 |
| assert var_name_set(free_vars(inner)) == {"x", "y"} |
| |
| |
| def test_all_global_vars(): |
| # there is one call to "func" |
| global_vars = all_global_vars(VarExample["main"]) |
| assert len(global_vars) == 1 |
| assert global_vars[0].name_hint == "func" |
| |
| gv1 = rx.GlobalVar("gv1") |
| gv2 = rx.GlobalVar("gv2") |
| gv3 = rx.GlobalVar("gv3") |
| call = rx.Call(gv1, [gv2, gv3]) |
| call_var_names = var_name_set(all_global_vars(call)) |
| assert call_var_names == {"gv1", "gv2", "gv3"} |
| |
| |
| def test_reshape_pattern_reshape(): |
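    # A (1, 2, 3, 4) -> (8, 3) reshape: the input coordinates are re-derived
    # from the flattened output index (ax0 * 3 + ax1).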
| @T.prim_func |
| def reshape( |
| rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), |
| T_reshape: T.Buffer((8, 3), "float32"), |
| ): |
| for i0, i1 in T.grid(8, 3): |
| with T.block("T_reshape"): |
| ax0, ax1 = T.axis.remap("SS", [i0, i1]) |
| T.reads( |
| rxplaceholder[ |
| (ax0 * 3 + ax1) // 24, |
| (ax0 * 3 + ax1) % 24 // 12, |
| (ax0 * 3 + ax1) % 12 // 4, |
| (ax0 * 3 + ax1) % 4, |
| ] |
| ) |
| T.writes(T_reshape[ax0, ax1]) |
| T_reshape[ax0, ax1] = rxplaceholder[ |
| (ax0 * 3 + ax1) // 24, |
| (ax0 * 3 + ax1) % 24 // 12, |
| (ax0 * 3 + ax1) % 12 // 4, |
| (ax0 * 3 + ax1) % 4, |
| ] |
| |
| assert has_reshape_pattern(reshape) |
| |
| |
| def test_reshape_pattern_reshape_scheduled(): |
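    # The same reshape after scheduling (loops fused and bound to GPU threads);
    # detection should be robust to such transformations.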
| @T.prim_func |
| def reshape_scheduled( |
| rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), |
| T_reshape: T.Buffer((8, 3), "float32"), |
| ): |
| for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"): |
| for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"): |
| with T.block("T_reshape"): |
| ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3) |
| ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3) |
| T.reads( |
| rxplaceholder[ |
| (ax0 * 3 + ax1) // 24, |
| (ax0 * 3 + ax1) % 24 // 12, |
| (ax0 * 3 + ax1) % 12 // 4, |
| (ax0 * 3 + ax1) % 4, |
| ] |
| ) |
| T.writes(T_reshape[ax0, ax1]) |
| T_reshape[ax0, ax1] = rxplaceholder[ |
| (ax0 * 3 + ax1) // 24, |
| (ax0 * 3 + ax1) % 24 // 12, |
| (ax0 * 3 + ax1) % 12 // 4, |
| (ax0 * 3 + ax1) % 4, |
| ] |
| |
| assert has_reshape_pattern(reshape_scheduled) |
| |
| |
| def test_reshape_pattern_expand_dims(): |
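    # expand_dims only inserts size-1 axes, so it still matches the reshape pattern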
| @T.prim_func |
| def expand_dims( |
| rxplaceholder: T.Buffer((2, 3, 4), "float32"), |
| expand_dims: T.Buffer((2, 1, 1, 1, 3, 1, 4, 1), "float32"), |
| ): |
| T.func_attr({"tir.noalias": True}) |
| for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): |
| with T.block("expand_dims"): |
| i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( |
| "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] |
| ) |
| T.reads(rxplaceholder[i0_1, i4_1, i6_1]) |
| T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) |
| expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[ |
| i0_1, i4_1, i6_1 |
| ] |
| |
| assert has_reshape_pattern(expand_dims) |
| |
| |
| def test_reshape_pattern_dyn_1(): |
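    # Reshape with a dynamic leading dimension: (n, 32, 128) -> (1, n, 32, 128)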
| @T.prim_func |
| def reshape(var_A: T.handle, var_T_reshape: T.handle): |
| n = T.int64() |
| A = T.match_buffer(var_A, (n, T.int64(32), T.int64(128)), "float16") |
| T_reshape = T.match_buffer( |
| var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16" |
| ) |
| for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)): |
| with T.block("T_reshape"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T.reads( |
| A[ |
| ((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, |
| (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), |
| v_ax3 % T.int64(128), |
| ] |
| ) |
| T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) |
| T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[ |
| ((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, |
| (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), |
| v_ax3 % T.int64(128), |
| ] |
| |
| assert has_reshape_pattern(reshape) |
| |
| |
| def test_reshape_pattern_dyn_2(): |
| @T.prim_func |
| def reshape(var_A: T.handle, var_T_reshape: T.handle): |
| n = T.int64() |
| A = T.match_buffer(var_A, (T.int64(1), n), "int32") |
| T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") |
| for ax0 in range(n): |
| with T.block("T_reshape"): |
| v_ax0 = T.axis.spatial(n, ax0) |
| T.reads(A[T.int64(0), v_ax0 % n]) |
| T.writes(T_reshape[v_ax0]) |
| T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n] |
| |
| assert has_reshape_pattern(reshape) |
| |
| |
| def test_reshape_pattern_dyn_3(): |
| @T.prim_func |
| def reshape(var_A: T.handle, var_T_reshape: T.handle): |
| T.func_attr({"op_pattern": 8, "tir.noalias": True}) |
| n = T.int64() |
| A = T.match_buffer(var_A, (n, T.int64(4096)), "float16") |
| T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16") |
| for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): |
| with T.block("T_reshape"): |
| v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) |
| T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]) |
| T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) |
| T_reshape[v_ax0, v_ax1, v_ax2] = A[ |
| (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) |
| ] |
| |
| assert has_reshape_pattern(reshape) |
| |
| |
| def test_reshape_pattern_dyn_4(): |
| @T.prim_func |
| def reshape(var_A: T.handle, var_T_reshape: T.handle): |
| T.func_attr({"op_pattern": 8, "tir.noalias": True}) |
| n = T.int64() |
| A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") |
| T_reshape = T.match_buffer( |
| var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16" |
| ) |
| for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)): |
| with T.block("T_reshape"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T.reads( |
| A[ |
| T.int64(0), |
| ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, |
| (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096), |
| ] |
| ) |
| T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) |
| T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[ |
| T.int64(0), |
| ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, |
| (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096), |
| ] |
| |
| assert has_reshape_pattern(reshape) |
| |
| |
| def test_reshape_pattern_dyn_5(): |
| @T.prim_func |
| def reshape(var_A: T.handle, var_T_reshape: T.handle): |
| T.func_attr({"op_pattern": 8, "tir.noalias": True}) |
| n = T.int64() |
| A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") |
| T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16") |
| # with T.block("root"): |
| for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): |
| with T.block("T_reshape"): |
| v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) |
| T.reads( |
| A[ |
| T.int64(0), |
| (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, |
| v_ax2 % T.int64(4096) // T.int64(128), |
| v_ax2 % T.int64(128), |
| ] |
| ) |
| T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) |
| T_reshape[v_ax0, v_ax1, v_ax2] = A[ |
| T.int64(0), |
| (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, |
| v_ax2 % T.int64(4096) // T.int64(128), |
| v_ax2 % T.int64(128), |
| ] |
| |
| assert has_reshape_pattern(reshape) |
| |
| |
| def test_reshape_pattern_with_raggedness(): |
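    # Each row of A is reshaped from (768,) to (12, 64); the ragged outer loop
    # over src_indptr offsets does not break the pattern.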
| @T.prim_func |
| def reshape_raggedness( |
| A: T.Buffer((100, 768), "float32"), |
| src_indptr: T.Buffer((9,), "int32"), |
| B: T.Buffer((100, 12, 64), "float32"), |
| ): |
| for b in T.serial(8): |
| with T.block("block0"): |
| vb = T.axis.spatial(8, b) |
| for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]): |
| for h in T.serial(12): |
| for f in T.serial(64): |
| with T.block("block1"): |
| vi, vh, vf = T.axis.remap("SSS", [i, h, f]) |
| B[src_indptr[vb] + vi, vh, vf] = A[ |
| src_indptr[vb] + vi, vh * 64 + vf |
| ] |
| |
| assert has_reshape_pattern(reshape_raggedness) |
| |
| |
| def test_reshape_pattern_reject_seqstmt(): |
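    # A body consisting of multiple sequential blocks (a SeqStmt) cannot be
    # recognized as a single reshape.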
| @T.prim_func |
| def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): |
        C = T.alloc_buffer((4, 4), "float32")
| for i0, i1 in T.grid(4, 4): |
| with T.block("identity"): |
| vi0, vi1 = T.axis.remap("SS", [i0, i1]) |
| C[vi0, vi1] = A[vi0, vi1] |
| for i0, i1 in T.grid(4, 4): |
| with T.block("identity"): |
| vi0, vi1 = T.axis.remap("SS", [i0, i1]) |
| B[vi0, vi1] = C[vi0, vi1] + T.float32(1) |
| |
| @T.prim_func |
| def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): |
        C = T.alloc_buffer((4, 4), "float32")
| for i0, i1 in T.grid(4, 4): |
| with T.block("identity"): |
| vi0, vi1 = T.axis.remap("SS", [i0, i1]) |
| C[vi0, vi1] = A[vi0, vi1] |
| for i0, i1 in T.grid(4, 4): |
| with T.block("identity"): |
| vi0, vi1 = T.axis.remap("SS", [i0, i1]) |
| B[vi0, vi1] = C[vi0, vi1] |
| |
| assert not has_reshape_pattern(identity_bias) |
| assert not has_reshape_pattern(identity_identity) |
| |
| |
| def test_reshape_pattern_reject_reduction(): |
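    # A block with a reduction axis ("R") cannot be a pure data-movement reshape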
| @T.prim_func |
| def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")): |
| for i0, i1 in T.grid(4, 4): |
| with T.block("identity"): |
| vi0, vi1 = T.axis.remap("SR", [i0, i1]) |
| with T.init(): |
| B[vi0] = T.float32(0) |
| B[vi0] = B[vi0] + A[vi0, vi1] |
| |
| assert not has_reshape_pattern(reduction) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |