| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import pytest |
| |
| import tvm |
| import tvm.testing |
| from tvm import tir |
| from tvm.ir import Range |
| from tvm.script import tir as T |
| |
| |
# Fixture for test_block_access_region_detector.
# The outer anonymous block carries hand-written read/write annotations. Buffer D is never
# indexed; it only appears through `T.evaluate(D.data)` inside that block.
@T.prim_func
def func() -> None:
    A = T.alloc_buffer((128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.alloc_buffer((128, 128), "float32")
    D = T.alloc_buffer((128, 128), "float32")
    with T.block():
        # Read/write regions are annotated manually to avoid triggering the block access
        # region detector.
        T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
        T.writes([A[0:12, 0:12]])
        # First loop nest: fill A[0:8, 0:8] from the scalars B[0, 0] and C[0, 0].
        for i, j in T.grid(8, 8):
            A[i, j] = B[0, 0] + C[0, 0]
        # Second loop nest: a 2x2 grid of inner blocks, each updating one 4x4 tile of A
        # (tiles start at offset 4) with values from C[12:16, 12:16].
        for i, j in T.grid(2, 2):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]])
                T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]])
                # The innermost loop reuses the names i, j for offsets inside the tile.
                for i, j in T.grid(4, 4):
                    A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
        # D is referenced only through its data pointer; the test expects it as the sole
        # entry of the third result list, covering the whole buffer.
        T.evaluate(D.data)
| |
| |
# Fixture for test_match_buffer.
# Block "block" views sub-regions of A and B through T.match_buffer (AA, B0, B1); the nested
# block "AAA" writes through a further scalar view of AA. B0 and B1 are only referenced via
# their `.data` pointers.
@T.prim_func
def match_buffer_func() -> None:
    with T.block("root"):
        A = T.alloc_buffer((128, 128), "float32")
        B = T.alloc_buffer((128, 128), "float32")
        T.reads([])
        T.writes([])
        # Read/write regions are annotated manually to avoid triggering the block access
        # region detector.
        for i, j in T.grid(8, 8):
            with T.block("block"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16])
                T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                # AA aliases the 16x16 tile of A owned by this block instance.
                AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16))
                # B0 (4x4) and B1 (4x8) alias two sub-regions of B inside the same tile.
                B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4))
                B1 = T.match_buffer(
                    B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)
                )
                for ii, jj in T.grid(16, 16):
                    with T.block("AAA"):
                        vii, vjj = T.axis.remap("SS", [ii, jj])
                        T.reads([])
                        T.writes(AA[vii, vjj])
                        # Zero-dimensional view of a single element of AA.
                        AAA = T.match_buffer(AA[vii, vjj], ())
                        AAA[()] = 1.0
                # B is touched only through the data pointers of its matched views.
                T.evaluate(B0.data)
                T.evaluate(B1.data)
| |
| |
# Fixture for test_opaque_block.
# Two nested anonymous blocks that declare no block iteration variables; they index the
# buffers directly with the surrounding loop variables i and j.
@T.prim_func
def opaque_block_func() -> None:
    with T.block("root"):
        A = T.alloc_buffer((16, 16), "float32")
        B = T.alloc_buffer((16, 16), "float32")
        T.reads([])
        T.writes([])
        # Read/write regions are annotated manually to avoid triggering the block access
        # region detector.
        for i in range(0, 16):
            with T.block():
                # Outer block: annotated with one full row of A and of B.
                T.reads(A[i, 0:16])
                T.writes([B[i, 0:16]])
                for j in range(0, 16):
                    with T.block():
                        # Inner block: annotated with a single element.
                        T.reads(A[i, j])
                        T.writes(B[i, j])
                        B[i, j] = A[i, j] + 1.0
| |
| |
# Fixture for test_opaque_access.
# The block body never indexes A or B; both buffers are handed to an extern call through
# their data pointers together with an offset (v * 128) and a length (128).
@T.prim_func
def opaque_access_func() -> None:
    A = T.alloc_buffer([1024])
    B = T.alloc_buffer([1024])
    for i in T.serial(0, 8):
        with T.block():
            v = T.axis.S(8, i)
            # Annotated regions: one 128-element chunk of A (read) and of B (write).
            T.reads([A[v * 128 : v * 128 + 128]])
            T.writes([B[v * 128 : v * 128 + 128]])
            T.evaluate(
                T.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )
| |
| |
# Fixture for test_opaque_access_with_tvm_access_ptr.
# Each buffer is touched only through `access_ptr` with a different access mask:
# A with "r", B with "w" and C with "rw". The annotations list C under both reads and writes.
@T.prim_func
def opaque_access_with_tvm_access_ptr_func() -> None:
    A = T.alloc_buffer([1024])
    B = T.alloc_buffer([1024])
    C = T.alloc_buffer([1024])
    with T.block("opaque"):
        T.reads(A[0:1024], C[0:1024])
        T.writes(B[0:1024], C[0:1024])
        T.evaluate(A.access_ptr("r"))
        T.evaluate(B.access_ptr("w"))
        T.evaluate(C.access_ptr("rw"))
| |
| |
# Fixture for test_access_in_if_then_else_func.
# A[i] is read only in the `i < 5` arm of T.if_then_else, so the annotated read region is
# A[0:5] although the loop runs over 0..7; B is written for every i.
@T.prim_func
def access_in_if_then_else_func() -> None:
    A = T.alloc_buffer([8])
    B = T.alloc_buffer([8])
    with T.block():
        T.reads([A[0:5]])
        T.writes([B[0:8]])
        for i in T.serial(0, 8):
            B[i] = T.if_then_else(i < 5, A[i], 0.0, dtype="float32")
| |
| |
# Fixture for test_access_in_branch_func.
# The `if` arm reads A[i] for i < 5 and the `else` arm reads A[i - 1] for i in 5..7
# (indices 4..6), so the annotated read region is A[0:7]; B is written for every i.
@T.prim_func
def access_in_branch_func() -> None:
    A = T.alloc_buffer([8])
    B = T.alloc_buffer([8])
    with T.block():
        T.reads([A[0:7]])
        T.writes([B[0:8]])
        for i in T.serial(0, 8):
            if i < 5:
                B[i] = A[i] + 1.0
            else:
                B[i] = A[i - 1]
| |
| |
# Fixture for test_access_of_reduction.
# A single reduction block ("update") with a T.init() section. The annotated reads list only
# A and B; C appears only in the writes although the body uses `+=` on it.
@T.prim_func
def gemm() -> None:
    A = T.alloc_buffer([16, 16], "float32")
    B = T.alloc_buffer([16, 16], "float32")
    C = T.alloc_buffer([16, 16], "float32")
    for i, j, k, ii, jj in T.grid(4, 4, 16, 4, 4):
        with T.block("update"):
            # vi / vj are spatial axes built from the tiled loops; vk is the reduction axis.
            vi = T.axis.S(16, i * 4 + ii)
            vj = T.axis.S(16, j * 4 + jj)
            vk = T.axis.R(16, k)
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = 0
            C[vi, vj] += A[vi, vk] * B[vj, vk]
| |
| |
# Fixture for test_access_of_decompose_reduction.
# Same computation as a GEMM reduction, but the initialisation lives in its own "init" block
# and the "update" block therefore lists C[vi, vj] among its reads.
@T.prim_func
def decomposed_gemm() -> None:
    A = T.alloc_buffer([16, 16], "float32")
    B = T.alloc_buffer([16, 16], "float32")
    C = T.alloc_buffer([16, 16], "float32")
    for i, j in T.grid(4, 4):
        # First statement under (i, j): zero the 4x4 tile of C.
        for ii, jj in T.grid(4, 4):
            with T.block("init"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                T.reads([])
                T.writes(C[vi, vj])
                C[vi, vj] = 0
        # Second statement under (i, j): accumulate over the reduction axis k.
        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                vk = T.axis.R(16, k)
                T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
                T.writes(C[vi, vj])
                C[vi, vj] += A[vi, vk] * B[vj, vk]
| |
| |
# Fixture for test_access_of_padding_pattern.
# "padding" copies the 28x28 buffer X into the centre of the 32x32 buffer X_pad (border set
# to 0.0); "padding_reverse" copies the centre of X_pad back into the 28x28 buffer Y.
# Both blocks guard the shifted index (vi - 2, vj - 2) with the same bounds condition.
@T.prim_func
def access_of_padding_pattern() -> None:
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X[vi - 2, vj - 2]])
            T.writes([X_pad[vi, vj]])
            # X is read only when (vi, vj) lies in [2, 30) x [2, 30).
            X_pad[vi, vj] = T.if_then_else(
                2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
            )
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X_pad[vi, vj]])
            T.writes([Y[vi - 2, vj - 2]])
            # Y is written only when (vi, vj) lies in [2, 30) x [2, 30).
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]
| |
| |
def test_block_access_region_detector():
    """Detected reads/writes of `func`'s outer block equal its annotations; D is in ret[2]."""
    root = func.body.block
    target = root.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    regions = tir.analysis.get_block_access_region(target, var_to_buffer)
    tvm.ir.assert_structural_equal(target.reads, regions[0])
    tvm.ir.assert_structural_equal(target.writes, regions[1])

    # D is the last allocated buffer; the third result must cover it entirely.
    whole_d = tvm.tir.BufferRegion(root.alloc_buffers[-1], [Range(0, 128), Range(0, 128)])
    tvm.ir.assert_structural_equal([whole_d], regions[2])
| |
| |
def test_opaque_block():
    """Both nested blocks of `opaque_block_func` must reproduce their annotated regions."""
    root = opaque_block_func.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    outer = root.body.body.block
    inner = outer.body.body.block
    for blk in (outer, inner):
        regions = tir.analysis.get_block_access_region(blk, var_to_buffer)
        tvm.ir.assert_structural_equal(blk.reads, regions[0])
        tvm.ir.assert_structural_equal(blk.writes, regions[1])
| |
| |
def test_opaque_access():
    """For the extern-call block the two analyses must differ on both reads and writes."""
    root = opaque_access_func.body.block
    target = root.body.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    rw_regions = tir.analysis.get_block_read_write_region(target, var_to_buffer)
    access_regions = tir.analysis.get_block_access_region(target, var_to_buffer)

    # Index 0 holds the read regions, index 1 the write regions.
    for idx in (0, 1):
        with pytest.raises(ValueError):
            tvm.ir.assert_structural_equal(rw_regions[idx], access_regions[idx])
| |
| |
def test_opaque_access_with_tvm_access_ptr():
    """get_block_read_write_region matches the annotations; get_block_access_region differs."""
    root = opaque_access_with_tvm_access_ptr_func.body.block
    opaque = root.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    rw_regions = tir.analysis.get_block_read_write_region(opaque, var_to_buffer)
    access_regions = tir.analysis.get_block_access_region(opaque, var_to_buffer)

    tvm.ir.assert_structural_equal(opaque.reads, rw_regions[0])
    tvm.ir.assert_structural_equal(opaque.writes, rw_regions[1])

    # Index 0 holds the read regions, index 1 the write regions.
    for idx in (0, 1):
        with pytest.raises(ValueError):
            tvm.ir.assert_structural_equal(rw_regions[idx], access_regions[idx])
| |
| |
def test_match_buffer():
    """Check region detection on `match_buffer_func`, before and after registering views."""
    root = match_buffer_func.body.block
    outer = root.body.body.body.block
    innermost = outer.body[0].body.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    # Block "block": writes match the annotation.
    outer_regions = tir.analysis.get_block_access_region(outer, var_to_buffer)
    tvm.ir.assert_structural_equal(outer.writes, outer_regions[1])
    # B is an opaque access, so its annotated reads show up in the third list.
    tvm.ir.assert_structural_equal(outer.reads, outer_regions[2])

    # Block "AAA" while AA is still unknown to the map: its region is not collected.
    unregistered = tir.analysis.get_block_access_region(innermost, var_to_buffer)
    tvm.ir.assert_structural_equal([], unregistered[1])

    # Register the matched buffers of the outer block, then re-check block "AAA".
    var_to_buffer.update({mb.buffer.data: mb.buffer for mb in outer.match_buffers})

    registered = tir.analysis.get_block_access_region(innermost, var_to_buffer)
    tvm.ir.assert_structural_equal(innermost.reads, registered[0])
    tvm.ir.assert_structural_equal(innermost.writes, registered[1])
| |
| |
def test_access_in_if_then_else_func():
    """Both analyses must agree on reads and writes for the T.if_then_else fixture."""
    root = access_in_if_then_else_func.body.block
    target = root.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    rw_regions = tir.analysis.get_block_read_write_region(target, var_to_buffer)
    access_regions = tir.analysis.get_block_access_region(target, var_to_buffer)
    for idx in (0, 1):
        tvm.ir.assert_structural_equal(rw_regions[idx], access_regions[idx])
| |
| |
def test_access_in_branch_func():
    """Both analyses must agree on reads and writes for the if/else branch fixture."""
    root = access_in_branch_func.body.block
    target = root.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    rw_regions = tir.analysis.get_block_read_write_region(target, var_to_buffer)
    access_regions = tir.analysis.get_block_access_region(target, var_to_buffer)
    for idx in (0, 1):
        tvm.ir.assert_structural_equal(rw_regions[idx], access_regions[idx])
| |
| |
def test_access_of_padding_pattern():
    """Compare detected regions of the padding blocks against their annotations."""
    sch = tvm.tir.schedule.Schedule(access_of_padding_pattern)
    root_stmt = sch.get_sref(sch.get_block("root")).stmt
    var_to_buffer = {buffer.data: buffer for buffer in root_stmt.alloc_buffers}

    def compare_region(observed, expected):
        """Check that both regions refer to the same buffer and walk their ranges."""
        assert observed.buffer == expected.buffer
        analyzer = tvm.arith.Analyzer()
        for got, want in zip(observed.region, expected.region):
            # NOTE(review): the boolean results of can_prove_equal are discarded, so a
            # range mismatch never fails this test -- confirm whether asserts were intended.
            analyzer.can_prove_equal(got.min, want.min)
            analyzer.can_prove_equal(got.extent, want.extent)

    def check_block(block_name):
        """Run the detector on the named block and compare reads, then writes."""
        stmt = sch.get_sref(sch.get_block(block_name)).stmt
        annotated_reads = stmt.reads
        annotated_writes = stmt.writes
        detected = tir.analysis.get_block_access_region(stmt, var_to_buffer)
        for found, annotated in ((detected[0], annotated_reads), (detected[1], annotated_writes)):
            for pos, region in enumerate(found):
                compare_region(region, annotated[pos])

    for name in ("padding", "padding_reverse"):
        check_block(name)
| |
| |
def test_access_of_reduction():
    """The reduction block of `gemm` must reproduce its annotated reads and writes."""
    root = gemm.body.block
    # Five nested loops sit between the root block and the "update" block realize.
    update = root.body.body.body.body.body.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    regions = tir.analysis.get_block_access_region(update, var_to_buffer)
    tvm.ir.assert_structural_equal(update.reads, regions[0])
    tvm.ir.assert_structural_equal(update.writes, regions[1])
| |
| |
def test_access_of_decompose_reduction():
    """The "init" and "update" blocks of `decomposed_gemm` must match their annotations."""
    root = decomposed_gemm.body.block
    # Statements under the (i, j) loops: [0] is the init nest, [1] is the update nest.
    tile_body = root.body.body.body
    init_block = tile_body[0].body.body.block
    update_block = tile_body[1].body.body.body.block
    var_to_buffer = {buffer.data: buffer for buffer in root.alloc_buffers}

    for blk in (init_block, update_block):
        regions = tir.analysis.get_block_access_region(blk, var_to_buffer)
        tvm.ir.assert_structural_equal(blk.reads, regions[0])
        tvm.ir.assert_structural_equal(blk.writes, regions[1])
| |
| |
def test_buffer_access_with_let_binding():
    """Detected regions must match the annotations when indices come from let-bound loads.

    The "copy" block binds `seq_slot_ids[vi]` and `history_slot_ids[vi]` to local int32
    variables and uses those variables to index `storage`.
    """

    @T.prim_func
    def func(
        storage: T.Buffer((16, 16, 16), "float32"),
        seq_slot_ids: T.Buffer((16,), "int32"),
        history_slot_ids: T.Buffer((16,), "int32"),
        output: T.Buffer((16, 16), "float32"),
    ):
        for i, s in T.grid(16, 16):
            with T.block("copy"):
                vi, vs = T.axis.remap("SS", [i, s])
                # The annotation spells the storage index out with the buffer loads
                # themselves rather than with the let-bound names used in the body.
                T.reads(
                    seq_slot_ids[vi],
                    history_slot_ids[vi],
                    storage[seq_slot_ids[vi], history_slot_ids[vi], vs],
                )
                T.writes(output[vi, vs])
                seq_id: T.int32 = seq_slot_ids[vi]
                history_id: T.int32 = history_slot_ids[vi]
                output[vi, vs] = storage[seq_id, history_id, vs]

    block = func.body.block.body.body.body.block
    # Buffers are function parameters here, so the map is built from buffer_map.
    buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
    ret = tir.analysis.get_block_access_region(block, buffer_var_map)
    tvm.ir.assert_structural_equal(block.reads, ret[0])
    tvm.ir.assert_structural_equal(block.writes, ret[1])
| |
| |
def test_buffer_access_with_nested_let_binding():
    """Detected regions must match the annotations when indices are chained let bindings.

    Every index used in the "copy" block body is an alias (possibly through several
    bindings) of the block iteration variables `vi` or `vs`.
    """

    @T.prim_func
    def func(
        A: T.Buffer((16, 16), "float32"),
        B: T.Buffer((16, 16), "float32"),
        C: T.Buffer((16, 16), "float32"),
    ):
        for i, s in T.grid(16, 16):
            with T.block("copy"):
                vi, vs = T.axis.remap("SS", [i, s])
                T.reads(A[vi, vs], B[vi, vs])
                T.writes(C[vi, vs])
                # Alias chains: vi2 -> vi1 -> vi and vs3 -> vs2 -> vs1 -> vs.
                vi1: T.int32 = vi
                vi2: T.int32 = vi1
                vs1: T.int32 = vs
                vs2: T.int32 = vs1
                vs3: T.int32 = vs2
                C[vi, vs1] = A[vi1, vs2] + B[vi2, vs3]

    block = func.body.block.body.body.body.block
    # Buffers are function parameters here, so the map is built from buffer_map.
    buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
    ret = tir.analysis.get_block_access_region(block, buffer_var_map)
    tvm.ir.assert_structural_equal(block.reads, ret[0])
    tvm.ir.assert_structural_equal(block.writes, ret[1])
| |
| |
# Script entry point: delegate to tvm.testing.main() when this file is executed directly.
if __name__ == "__main__":
    tvm.testing.main()