| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=missing-function-docstring,missing-module-docstring |
| import tvm |
| import tvm.testing |
| from tvm import tir |
| from tvm.script import tir as T |
| from tvm.tir.schedule.testing import verify_trace_roundtrip |
| import pytest |
| |
| # fmt: off |
| # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks |
| |
# Input workload for test_blockize_outer and test_blockize_inner: a single
# block "B" computing B = A * 2.0 over a two-level (i, j) loop nest of
# extent 128 x 128.
@T.prim_func
def single_elementwise(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")):
    for i, j in T.grid(128, 128):
        with T.block("B"):
            # Both block iterators are spatial and bound 1:1 to the loop vars.
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
| |
| # fmt: on |
| # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks |
| |
| |
def test_blockize_outer():
    """Blockize the outer loop ``i`` of ``single_elementwise``.

    The whole ``(i, j)`` loop nest moves under a new block ``blockized_B``.
    No loop remains outside that block, so both of its iterators are
    unit-extent and bound to the constant 0.
    """

    @T.prim_func
    def after_blockize_outer(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
    ) -> None:
        # New outer block; nothing surrounds it, so its iterators are trivial.
        with T.block("blockized_B"):
            vio = T.axis.spatial(1, 0)
            vjo = T.axis.spatial(1, 0)
            # The original loop nest and block "B" are carried over unchanged.
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0

    func = single_elementwise
    s = tir.Schedule(func, debug_mask="all")
    # Loops around "B", outermost first: (i, j). Blockize at the outer one.
    x, _ = s.get_loops(s.get_block("B"))
    s.blockize(x)
    # Give the expected function the same global_symbol as the input function
    # before comparing structurally.
    tvm.ir.assert_structural_equal(
        s.mod["main"], after_blockize_outer.with_attr("global_symbol", "single_elementwise")
    )
    # Additionally check the recorded schedule trace against the original func.
    verify_trace_roundtrip(sch=s, mod=func)
| |
| |
def test_blockize_inner():
    """Blockize the inner loop ``j`` of ``single_elementwise``.

    Loop ``i`` stays outside. For each ``i``, a new block ``blockized_B``
    wraps the ``j`` loop. The new block's first iterator is bound to ``i``,
    and its second is unit-extent because no ``j`` loop remains outside it.
    """

    @T.prim_func
    def after_blockize_inner(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
    ) -> None:
        for i in T.serial(128):
            with T.block("blockized_B"):
                # Outer iterator follows the remaining loop i; the j part is trivial.
                vi = T.axis.spatial(128, i)
                vjo = T.axis.spatial(1, 0)
                for j in T.serial(128):
                    with T.block("B"):
                        # Inner block only binds j; vi comes from the outer block.
                        vj = T.axis.remap("S", [j])
                        B[vi, vj] = A[vi, vj] * 2.0

    func = single_elementwise
    s = tir.Schedule(func, debug_mask="all")
    # Loops around "B", outermost first: (i, j). Blockize at the inner one.
    _, y = s.get_loops(s.get_block("B"))
    s.blockize(y)
    # Give the expected function the same global_symbol as the input function
    # before comparing structurally.
    tvm.ir.assert_structural_equal(
        s.mod["main"], after_blockize_inner.with_attr("global_symbol", "single_elementwise")
    )
    # Additionally check the recorded schedule trace against the original func.
    verify_trace_roundtrip(sch=s, mod=func)
| |
| |
def test_two_elementwise_blockize_reverse_compute_at():
    """Blockize a consumer that sits under an already-blockized producer's loops.

    In the input:
    - Producer ``B`` is already wrapped in a 16x16-tile block ``B_o`` under
      loops ``(i, j)``.
    - Consumer ``C`` is computed inside the same ``(i, j)`` loops over
      ``(ax0, ax1)``.

    Blockizing ``C`` at ``ax0`` produces a matching ``C_o`` block. Its
    iterators are remapped from ``(i, j)``, and its read/write regions are
    the corresponding 16x16 tiles.
    """

    @T.prim_func
    def before_blockize_rca(
        A: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ) -> None:
        B = T.alloc_buffer([128, 128], dtype="float32")
        for i, j in T.grid(8, 8):
            # Producer: already blockized into 16x16 tiles.
            with T.block("B_o"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                for i_1, j_1 in T.grid(16, 16):
                    with T.block("B"):
                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
                        T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
                        T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
                        B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
            # Consumer: computed per tile, not yet blockized.
            for ax0, ax1 in T.grid(16, 16):
                with T.block("C"):
                    vi = T.axis.spatial(128, i * 16 + ax0)
                    vj = T.axis.spatial(128, j * 16 + ax1)
                    T.reads(B[vi, vj])
                    T.writes(C[vi, vj])
                    C[vi, vj] = B[vi, vj] + 1.0

    @T.prim_func
    def after_blockize_rca(
        A: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ) -> None:
        B = T.alloc_buffer([128, 128], dtype="float32")
        for i, j in T.grid(8, 8):
            with T.block("B_o"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                for i_1, j_1 in T.grid(16, 16):
                    with T.block("B"):
                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
                        T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
                        T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
                        B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
            # New outer block for C; it mirrors B_o's tiling and access regions.
            with T.block("C_o"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("C"):
                        vi_i, vj_i = T.axis.remap("SS", [ax0, ax1])
                        T.reads(B[vi * 16 + vi_i, vj * 16 + vj_i])
                        T.writes(C[vi * 16 + vi_i, vj * 16 + vj_i])
                        C[vi * 16 + vi_i, vj * 16 + vj_i] = B[vi * 16 + vi_i, vj * 16 + vj_i] + 1.0

    func = before_blockize_rca
    s = tir.Schedule(func, debug_mask="all")
    # Loops around "C", outermost first: (i, j, ax0, ax1). Blockize at ax0.
    _, _, x, _ = s.get_loops(s.get_block("C"))
    s.blockize(x)
    # Give the expected function the same global_symbol as the input function
    # before comparing structurally.
    tvm.ir.assert_structural_equal(
        s.mod["main"], after_blockize_rca.with_attr("global_symbol", "before_blockize_rca")
    )
    # Additionally check the recorded schedule trace against the original func.
    verify_trace_roundtrip(sch=s, mod=func)
| |
| |
def test_two_elementwise_blockize_compute_at():
    """Blockize a producer that sits next to an already-blockized consumer.

    In the input:
    - Consumer ``C`` is already wrapped in a 16x16-tile block ``C_o`` under
      loops ``(i_0, j_0)``.
    - Producer ``B`` is computed inside the same loops over ``(ax0, ax1)``.

    Blockizing ``B`` at ``ax0`` produces a ``B_o`` block with iterators
    remapped from ``(i_0, j_0)`` and 16x16-tile read/write regions.
    """

    @T.prim_func
    def before_blockize_compute_at(
        A: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ) -> None:
        # body
        # with T.block("root")
        B = T.alloc_buffer([128, 128], dtype="float32")
        for i_0, j_0 in T.grid(8, 8):
            # Producer: computed per tile, not yet blockized.
            for ax0, ax1 in T.grid(16, 16):
                with T.block("B"):
                    vi = T.axis.spatial(128, i_0 * 16 + ax0)
                    vj = T.axis.spatial(128, j_0 * 16 + ax1)
                    T.reads(A[vi, vj])
                    T.writes(B[vi, vj])
                    B[vi, vj] = A[vi, vj] * 2.0
            # Consumer: already blockized into 16x16 tiles.
            with T.block("C_o"):
                vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
                T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                for i_1, j_1 in T.grid(16, 16):
                    with T.block("C"):
                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
                        T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
                            B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0
                        )

    @T.prim_func
    def after_blockize_compute_at(
        A: T.Buffer((128, 128), "float32"),
        C: T.Buffer((128, 128), "float32"),
    ) -> None:
        B = T.alloc_buffer([128, 128], dtype="float32")
        for i_0, j_0 in T.grid(8, 8):
            # New outer block for B; it mirrors C_o's tiling and access regions.
            with T.block("B_o"):
                vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
                T.reads(A[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                T.writes(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("B"):
                        vi_i, vj_i = T.axis.remap("SS", [ax0, ax1])
                        T.reads(A[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        T.writes(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
                            A[vi_o * 16 + vi_i, vj_o * 16 + vj_i] * 2.0
                        )
            with T.block("C_o"):
                vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
                T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                for i_1, j_1 in T.grid(16, 16):
                    with T.block("C"):
                        vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
                        T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = (
                            B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0
                        )

    func = before_blockize_compute_at
    s = tir.Schedule(func, debug_mask="all")
    # Loops around "B", outermost first: (i_0, j_0, ax0, ax1). Blockize at ax0.
    _, _, x, _ = s.get_loops(s.get_block("B"))
    s.blockize(x)
    # Give the expected function the same global_symbol as the input function
    # before comparing structurally.
    tvm.ir.assert_structural_equal(
        s.mod["main"],
        after_blockize_compute_at.with_attr("global_symbol", "before_blockize_compute_at"),
    )
    # Additionally check the recorded schedule trace against the original func.
    verify_trace_roundtrip(sch=s, mod=func)
| |
| |
def test_blockize_init_loops():
    """Blockize a reduction block that has an init statement.

    ``rowsum`` reduces over the outer loop ``k`` (a reduction iterator). It
    initializes ``B[vi]`` to 0 via ``T.init()``. Blockizing at ``k`` has
    three effects:
    - The whole nest is wrapped in ``blockized_B``, with a unit reduction
      iterator and a unit spatial iterator.
    - The init is lifted into the new block's ``T.init()`` as its own loop
      over the spatial axis, with block ``B_init``.
    - The inner block ``B`` no longer has an init.
    """

    @T.prim_func
    def rowsum(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")) -> None:
        for k, i in T.grid(128, 128):
            with T.block("B"):
                # "R" = reduction over k, "S" = spatial over i.
                vk, vi = T.axis.remap("RS", [k, i])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]

    @T.prim_func
    def after_rowsum_blockize(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128,), "float32"),
    ) -> None:
        with T.block("blockized_B"):
            vko = T.axis.R(1, 0)
            vio = T.axis.S(1, 0)
            # Init hoisted to the outer block, looping over the spatial axis only.
            with T.init():
                for i1 in T.serial(0, 128):
                    with T.block("B_init"):
                        vi_init = T.axis.S(128, i1)
                        B[vi_init] = T.float32(0)
            # Inner block keeps the update but no longer carries an init.
            for i0, i1_1 in T.grid(128, 128):
                with T.block("B"):
                    vk, vi = T.axis.remap("RS", [i0, i1_1])
                    B[vi] = B[vi] + A[vi, vk]

    s = tir.Schedule(rowsum, debug_mask="all")
    # Loops around "B", outermost first: (k, i). Blockize at the reduction loop k.
    k, _ = s.get_loops(s.get_block("B"))
    s.blockize(k)
    # Give the expected function the same global_symbol as the input function
    # before comparing structurally.
    tvm.ir.assert_structural_equal(
        s.mod["main"], after_rowsum_blockize.with_attr("global_symbol", "rowsum")
    )
    # Additionally check the recorded schedule trace against the original func.
    verify_trace_roundtrip(sch=s, mod=rowsum)
| |
| |
@pytest.mark.parametrize("preserve_unit_iters", [True, False])
def test_blockize_outer_int64_shape(preserve_unit_iters):
    """Blockize an int64-typed loop nest at its third loop.

    The input nest is ``(i0, j0, i1, j1)`` with extents ``(1, 8, 16, 16)``.
    Blockizing at ``i1`` yields block ``B_o`` under ``(i0, j0)``. The two
    expected outputs differ only in how ``vi_o`` is bound:
    - ``preserve_unit_iters=True``: bound to the unit loop ``i0``.
    - ``preserve_unit_iters=False``: bound to the constant ``T.int64(0)``.
    """

    @T.prim_func
    def single_elementwise_int64(
        A: T.Buffer((T.int64(16), T.int64(128)), "float32"),
        B: T.Buffer((T.int64(16), T.int64(128)), "float32"),
    ) -> None:
        for i0, j0, i1, j1 in T.grid(T.int64(1), T.int64(8), T.int64(16), T.int64(16)):
            with T.block("B"):
                vi = T.axis.S(T.int64(16), i0 * T.int64(16) + i1)
                vj = T.axis.S(T.int64(128), j0 * T.int64(16) + j1)
                B[vi, vj] = A[vi, vj] + 1.0

    @T.prim_func
    def after_single_elementwise_int64_blockize(
        A: T.Buffer((T.int64(16), T.int64(128)), "float32"),
        B: T.Buffer((T.int64(16), T.int64(128)), "float32"),
    ) -> None:
        for i0, j0 in T.grid(T.int64(1), T.int64(8)):
            with T.block("B_o"):
                vi_o = T.axis.spatial(T.int64(1), T.int64(0))
                vj_o = T.axis.spatial(T.int64(8), j0)
                for i1, j1 in T.grid(T.int64(16), T.int64(16)):
                    with T.block("B"):
                        vi_i, vj_i = T.axis.remap("SS", [i1, j1])
                        B[vi_i, vj_o * T.int64(16) + vj_i] = A[
                            vi_i, vj_o * T.int64(16) + vj_i
                        ] + T.float32(1)

    @T.prim_func
    def after_single_elementwise_int64_blockize_preserve_unit_iters(
        A: T.Buffer((T.int64(16), T.int64(128)), "float32"),
        B: T.Buffer((T.int64(16), T.int64(128)), "float32"),
    ) -> None:
        for i0, j0 in T.grid(T.int64(1), T.int64(8)):
            with T.block("B_o"):
                vi_o = T.axis.spatial(T.int64(1), i0)
                vj_o = T.axis.spatial(T.int64(8), j0)
                for i1, j1 in T.grid(T.int64(16), T.int64(16)):
                    with T.block("B"):
                        vi_i, vj_i = T.axis.remap("SS", [i1, j1])
                        B[vi_i, vj_o * T.int64(16) + vj_i] = A[
                            vi_i, vj_o * T.int64(16) + vj_i
                        ] + T.float32(1)

    # The expected IR depends only on the parameter, so select it up front and
    # tag it with the input's global_symbol for the structural comparison.
    if preserve_unit_iters:
        expected_func = after_single_elementwise_int64_blockize_preserve_unit_iters
    else:
        expected_func = after_single_elementwise_int64_blockize
    expected_func = expected_func.with_attr("global_symbol", "single_elementwise_int64")

    sch = tir.Schedule(single_elementwise_int64, debug_mask="all")
    # Loops around "B", outermost first: (i0, j0, i1, j1); blockize at i1.
    _i0, _j0, third_loop, _j1 = sch.get_loops(sch.get_block("B"))
    sch.blockize(third_loop, preserve_unit_iters=preserve_unit_iters)

    tvm.ir.assert_structural_equal(sch.mod["main"], expected_func)
    verify_trace_roundtrip(sch=sch, mod=single_elementwise_int64)
| |
| |
def test_blockize_blocks():
    """Blockize two sibling blocks together into a single new outer block.

    Blocks ``B`` and ``C`` both live under loop ``m``. Blockizing the list
    ``[B, C]`` yields one block ``outer_B_C_`` under ``m``, whose read/write
    regions span the full 128x128 buffers. With ``preserve_unit_iters=False``,
    the unit loop ``j`` of ``B`` is no longer bound to a block iterator.
    """

    @T.prim_func
    def blocks_func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")) -> None:
        for m in T.serial(6):
            for i, j in T.grid(3, 1):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(A[vi, vj])
                    T.writes(B[vi, vj])
                    B[vi, vj] = A[vi, vj] * 2.0

            for i, j in T.grid(128, 64):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(A[vi, vj + 64])
                    T.writes(B[vi, vj + 64])
                    B[vi, vj + 64] = A[vi, vj + 64] * 3.0

    @T.prim_func
    def after_blocks_blockize(
        A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")
    ) -> None:
        for m in range(6):
            with T.block("outer_B_C_"):
                vi_o = T.axis.spatial(1, 0)
                vj_o = T.axis.spatial(1, 0)
                T.reads(A[0:128, 0:128])
                T.writes(B[0:128, 0:128])
                for i, j in T.grid(3, 1):
                    with T.block("B"):
                        vi_i = T.axis.spatial(3, i)
                        T.reads(A[vi_i, 0])
                        T.writes(B[vi_i, 0])
                        B[vi_i, 0] = A[vi_i, 0] * T.float32(2)
                for i, j in T.grid(128, 64):
                    with T.block("C"):
                        vi_i, vj_i = T.axis.remap("SS", [i, j])
                        T.reads(A[vi_i, vj_i + 64])
                        T.writes(B[vi_i, vj_i + 64])
                        B[vi_i, vj_i + 64] = A[vi_i, vj_i + 64] * T.float32(3)

    sch = tir.Schedule(blocks_func, debug_mask="all")
    # Look up "B" then "C" (same order as the trace expects) and blockize both.
    sch.blockize(
        [sch.get_block(block_name) for block_name in ("B", "C")],
        preserve_unit_iters=False,
    )
    tvm.ir.assert_structural_equal(
        sch.mod["main"], after_blocks_blockize.with_attr("global_symbol", "blocks_func")
    )
    verify_trace_roundtrip(sch=sch, mod=blocks_func)
| |
| |
if __name__ == "__main__":
    # Entry point when this file is executed directly; delegates to tvm.testing.main().
    tvm.testing.main()