| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=missing-function-docstring,missing-module-docstring |
| import pytest |
| import tvm |
| import tvm.testing |
| import tvm.s_tir.tensor_intrin |
| from tvm import tir |
| from tvm.script import tir as T |
| from tvm.s_tir.schedule.testing import ( |
| verify_trace_roundtrip, |
| assert_structural_equal_ignore_global_symbol, |
| ) |
| |
| # pylint: disable=no-member,invalid-name,unused-variable |
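| # The prim_funcs below are paired fixtures for the compute_inline / reverse_compute_inline |
| # tests: each `*_inlined` function is the expected IR after inlining the corresponding |
| # producer or consumer block of the fixture it is named after. |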
| |
| |
| @T.prim_func |
| def elementwise(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| D = T.match_buffer(d, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("D"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers |
| |
| |
| @T.prim_func |
| def elementwise_multi_consumer_inlined(a: T.handle, c: T.handle, d: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| D = T.match_buffer(d, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] * 2.0 + 1.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("D"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] |
| |
| |
| @T.prim_func |
| def elementwise_standalone(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_standalone_dce(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_under_loop(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| for i in T.serial(0, 128): |
| for j in T.serial(0, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for j in T.serial(0, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_inlined(a: T.handle, c: T.handle) -> None: |
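| # Expected result: producer "B" and consumer "C" fused into one block computing C = A * 2.0 + 1.0. |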
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def fail_multi_reader_writer(a: T.handle, d: T.handle) -> None: |
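| # Block "B" writes two buffers (B and C), so inlining it is expected to raise a ScheduleError. |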
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.alloc_buffer((128, 128)) |
| D = T.match_buffer(d, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| C[vi, vj] = A[vi, vj] + 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[vi, vj] = B[vi, vj] + C[vi, vj] |
| |
| |
| @T.prim_func |
| def elementwise_multi_reverse_loads(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 |
| |
| |
| @T.prim_func |
| def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_load( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((8, 32, 8, 8), "float32") |
| ) -> None: |
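| # The consumer loads B through an invertible affine index, so reverse_compute_inline can |
| # rewrite the producer to write C directly. |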
| B = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j, k, l in T.grid(8, 32, 8, 8): |
| with T.sblock("C"): |
| vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) |
| C[vi, vj, vk, vl] = B[ |
| ((((vi * 32) + vj) * 8 + vk) * 8 + vl) // 128, |
| ((((vi * 32) + vj) * 8 + vk) * 8 + vl) % 128, |
| ] |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_load_inlined( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((8, 32, 8, 8), "float32") |
| ) -> None: |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[ |
| (vj + vi * 128) // 2048, |
| (vj + vi * 128) // 64 % 32, |
| ((vj + vi * 128) // 8) % 8, |
| (vj + vi * 128) % 8, |
| ] = ( |
| A[vi, vj] * 2.0 |
| ) |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_load_unit_iter( |
| A: T.Buffer((128, 128), "float32"), |
| B: T.Buffer((8, 16, 1), "float32"), |
| D: T.Buffer((1, 8, 16, 128), "float32"), |
| ) -> None: |
| C = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] * 2.0 |
| for i, j, k, l in T.grid(1, 8, 16, 128): |
| with T.sblock("C"): |
| vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) |
| D[vi, vj, vk, vl] = C[vj * 16 + vk, vl] + B[vj, vk, vi] |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_load_unit_iter_inlined( |
| A: T.Buffer((128, 128), "float32"), |
| B: T.Buffer((8, 16, 1), "float32"), |
| D: T.Buffer((1, 8, 16, 128), "float32"), |
| ) -> None: |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0] |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_load_unit_iter_simplified( |
| A: T.Buffer((128, 128), "float32"), |
| B: T.Buffer((8, 16, 1), "float32"), |
| D: T.Buffer((1, 8, 16, 128), "float32"), |
| ) -> None: |
| C = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] * 2.0 |
| for i, j, k in T.grid(8, 16, 128): |
| with T.sblock("C"): |
| vi, vj, vk = T.axis.remap("SSS", [i, j, k]) |
| D[0, vi, vj, vk] = C[vi * 16 + vj, vk] + B[vi, vj, 0] |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_load_unit_iter_simplified_inlined( |
| A: T.Buffer((128, 128), "float32"), |
| B: T.Buffer((8, 16, 1), "float32"), |
| D: T.Buffer((1, 8, 16, 128), "float32"), |
| ) -> None: |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0] |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_chain( |
| A: T.Buffer((128, 128), "float32"), D: T.Buffer((1, 8, 16, 128), "float32") |
| ): |
| B = T.alloc_buffer((128, 128)) |
| C = T.alloc_buffer((8, 16, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j, k in T.grid(8, 16, 128): |
| with T.sblock("C"): |
| vi, vj, vk = T.axis.remap("SSS", [i, j, k]) |
| C[vi, vj, vk] = B[vi * 16 + vj, vk] + 1.0 |
| for i, j, k, l in T.grid(1, 8, 16, 128): |
| with T.sblock("D"): |
| vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) |
| D[vi, vj, vk, vl] = C[vj, vk, vl] |
| |
| |
| @T.prim_func |
| def elementwise_reverse_affine_chain_inlined( |
| A: T.Buffer((128, 128), "float32"), D: T.Buffer((1, 8, 16, 128), "float32") |
| ) -> None: |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_multi_reverse_affine_load( |
| A: T.Buffer((128, 128), "float32"), |
| C: T.Buffer((8, 16, 128), "float32"), |
| ) -> None: |
| B = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j, k in T.grid(8, 16, 128): |
| with T.sblock("C"): |
| vi, vj, vk = T.axis.remap("SSS", [i, j, k]) |
| C[vi, vj, vk] = B[vi * 16 + vj, vk] + B[vi * 16 + vj, vk] |
| |
| |
| @T.prim_func |
| def elementwise_multi_reverse_affine_load_inlined( |
| A: T.Buffer((128, 128), "float32"), |
| C: T.Buffer((8, 16, 128), "float32"), |
| ) -> None: |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + A[vi, vj] * 2.0 |
| |
| |
| @T.prim_func |
| def elementwise_reverse_non_affine_load( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((8, 16, 128), "float32") |
| ) -> None: |
| B = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j, k in T.grid(8, 16, 128): |
| with T.sblock("C"): |
| vi, vj, vk = T.axis.remap("SSS", [i, j, k]) |
| C[vi, vj, vk] = B[vi * 16 + vj, vi * 16 + vj] |
| |
| |
| @T.prim_func |
| def opaque_access_load(a: T.handle, c: T.handle) -> None: |
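| # The consumer reads B through an opaque access_ptr, so inlining block "B" is expected to fail. |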
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.reads(B[0:128, 0:128]) |
| T.writes(C[0:128, 0:128]) |
| T.evaluate(B.access_ptr("r", extent=128)) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def opaque_access_store(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.reads(B[0:128, 0:128]) |
| T.writes(C[0:128, 0:128]) |
| T.evaluate(B.access_ptr("r", extent=128)) |
| T.evaluate(C.access_ptr("w", extent=128)) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def buffer_matched(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) |
| C[vi, vj] = Bb[0, 0] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_predicate(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.where(B[i, j] < 10.0) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.where(A[i, j] * 2.0 < 10.0) |
| C[vi, vj] = A[vi, vj] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_multi_loads(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] |
| |
| |
| @T.prim_func |
| def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 |
| |
| |
| @T.prim_func |
| def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None: |
| A = T.match_buffer(a, [1024]) |
| B = T.match_buffer(b, [1024]) |
| A_cache = T.alloc_buffer([1024]) |
| BB = T.alloc_buffer([1024]) |
| with T.sblock("opaque"): |
| # annotated opaque partial access |
| T.reads(A[0:512]) |
| T.writes(A_cache[0:512]) |
| T.evaluate(A.access_ptr("r", extent=512)) |
| T.evaluate(A_cache.access_ptr("w", extent=512)) |
| for i in range(512): |
| with T.sblock("BB"): |
| vi = T.axis.remap("S", [i]) |
| BB[vi] = A_cache[vi] * 2.0 |
| for i in range(512): |
| with T.sblock("B"): |
| vi = T.axis.remap("S", [i]) |
| B[vi] = BB[vi] + 1.0 |
| |
| |
| @T.prim_func |
| def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None: |
| A = T.match_buffer(a, [1024], dtype="float32") |
| B = T.match_buffer(b, [1024], dtype="float32") |
| A_cache = T.alloc_buffer([1024], dtype="float32") |
| with T.sblock("opaque"): |
| # annotated opaque partial access should be kept |
| T.reads(A[0:512]) |
| T.writes([A_cache[0:512]]) |
| T.evaluate(A.access_ptr("r", extent=512)) |
| T.evaluate(A_cache.access_ptr("w", extent=512)) |
| for i in T.serial(0, 512): |
| with T.sblock("B"): |
| vi = T.axis.spatial(512, i) |
| T.reads([A_cache[vi]]) |
| T.writes([B[vi]]) |
| B[vi] = A_cache[vi] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None: |
| A = T.match_buffer(var_A, [512, 512], dtype="float32") |
| B = T.match_buffer(var_B, [512, 512], dtype="float32") |
| compute = T.match_buffer(var_compute, [512, 512], dtype="float32") |
| C = T.alloc_buffer([512, 512], dtype="float32") |
| for i0, i1, i2 in T.grid(512, 512, 512): |
| with T.sblock("C"): |
| i, j, k = T.axis.remap("SSR", [i0, i1, i2]) |
| T.reads([C[i, j], A[i, k], B[k, j]]) |
| T.writes([C[i, j]]) |
| with T.init(): |
| C[i, j] = T.float32(0) |
| C[i, j] = C[i, j] + A[i, k] * B[k, j] |
| for i0, i1 in T.grid(512, 512): |
| with T.sblock("compute"): |
| i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) |
| T.reads([C[i0_1, i1_1]]) |
| T.writes([compute[i0_1, i1_1]]) |
| compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) |
| |
| |
| @T.prim_func |
| def elementwise_output(a: T.handle, b: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| B = T.match_buffer(b, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(128, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def inline_block_with_init( |
| A: T.Buffer((1, 512, 7, 7), "float32"), |
| B: T.Buffer((1, 512, 1, 1), "float32"), |
| ) -> None: |
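| # Block "tensor_rf" carries a T.init() reduction initializer, so compute_inline is expected to fail. |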
| B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32") |
| for i0, i1, i2, i3, i4, i5 in T.grid(1, 512, 1, 1, 49, 1): |
| with T.sblock("tensor_rf"): |
| vi4 = T.axis.spatial(49, i4) |
| ax0 = T.axis.spatial(1, 0) |
| ax1 = T.axis.spatial(512, i1) |
| ax2 = T.axis.spatial(1, 0) |
| ax3 = T.axis.spatial(1, 0) |
| with T.init(): |
| B_rf[ax0, ax1, ax2, ax3, vi4] = T.float32(0) |
| B_rf[ax0, ax1, ax2, ax3, vi4] = ( |
| B_rf[ax0, ax1, ax2, ax3, vi4] |
| + A[ |
| ax0, |
| ax1, |
| ax2 * 7 + vi4 // 7, |
| ax3 * 7 + vi4 % 7, |
| ] |
| ) |
| for i0, i1 in T.grid(1, 512): |
| for ax0, ax1, ax2, ax3, ax4 in T.grid(49, 1, 1, 1, 1): |
| with T.sblock("tensor"): |
| vi4, ax0_1 = T.axis.remap("RS", [ax0, ax1]) |
| ax1_1 = T.axis.spatial(512, i1 + ax2) |
| ax2_1, ax3_1 = T.axis.remap("SS", [ax3, ax4]) |
| with T.init(): |
| B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0) |
| B[ax0_1, ax1_1, ax2_1, ax3_1] = ( |
| B[ax0_1, ax1_1, ax2_1, ax3_1] + B_rf[ax0_1, ax1_1, ax2_1, ax3_1, vi4] |
| ) |
| |
| |
| @T.prim_func |
| def exp_exp_opaque_access_with_tvm_access_ptr( |
| lookup_table: T.Buffer((1024,), "int8"), |
| x: T.Buffer((16,), "float16"), |
| compute: T.Buffer((16,), "float16"), |
| ) -> None: |
| compute_1 = T.alloc_buffer([16], dtype="float16") |
| for i0 in T.serial(16): |
| with T.sblock("compute"): |
| i0_1 = T.axis.spatial(16, i0) |
| T.reads(x[i0_1]) |
| T.writes(compute_1[i0_1]) |
| compute_1[i0_1] = T.exp(x[i0_1], dtype="float16") |
| for i0 in T.serial(16): |
| with T.sblock("compute_1"): |
| i0_2 = T.axis.spatial(16, i0) |
| T.reads(lookup_table[0:1024], compute_1[i0_2]) |
| T.writes(compute[i0_2]) |
| T.evaluate(lookup_table.access_ptr("r")) |
| compute[i0_2] = T.exp( |
| compute_1[i0_2], |
| dtype="float16", |
| ) |
| |
| |
| @T.prim_func |
| def exp_exp_opaque_access_with_tvm_access_ptr_inlined( |
| lookup_table: T.Buffer((1024,), "int8"), |
| x: T.Buffer((16,), "float16"), |
| compute: T.Buffer((16,), "float16"), |
| ) -> None: |
| for i0 in T.serial(16): |
| with T.sblock("compute_1"): |
| i0_1 = T.axis.spatial(16, i0) |
| # Do not add the opaque access to the new write region when the opaque access |
| # is wrapped in a tvm_access_ptr whose access mask is set to "read only" |
| T.reads(lookup_table[0:1024], x[i0_1]) |
| T.writes(compute[i0_1]) |
| T.evaluate(lookup_table.access_ptr("r")) |
| compute[i0_1] = T.exp( |
| T.exp(x[i0_1], dtype="float16"), |
| dtype="float16", |
| ) |
| |
| |
| @T.prim_func |
| def elementwise_overcomputed_producer( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") |
| ) -> None: |
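| # Producer "B" computes a 128x128 region while consumer "C" only uses 127x127, so reverse |
| # inlining must guard the fused block with a predicate. |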
| B = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(127, 127): |
| with T.sblock("C"): |
| cvi, cvj = T.axis.remap("SS", [i, j]) |
| C[cvi, cvj] = B[cvi, cvj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_overcomputed_producer_reverse_inlined( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") |
| ) -> None: |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.where(i < 127 and j < 127) |
| C[vi, vj] = A[vi, vj] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_overcomputed_producer_simplify_predicate( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") |
| ) -> None: |
| B = T.alloc_buffer((128, 128)) |
| for i in T.grid(16384): |
| with T.sblock("B"): |
| vi = T.axis.spatial(128, i // 128) |
| vj = T.axis.spatial(128, i % 128) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(127, 127): |
| with T.sblock("C"): |
| cvi, cvj = T.axis.remap("SS", [i, j]) |
| C[cvi, cvj] = B[cvi, cvj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") |
| ) -> None: |
| for i in T.grid(16384): |
| with T.sblock("B"): |
| vi = T.axis.spatial(128, i // 128) |
| vj = T.axis.spatial(128, i % 128) |
| T.where(i < 16255 and i % 128 < 127) |
| C[vi, vj] = A[vi, vj] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_overcomputed_producer_injective_load( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") |
| ) -> None: |
| B = T.alloc_buffer((8, 8, 16, 16)) |
| for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): |
| with T.sblock("B"): |
| vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) |
| B[vi, vj, vm, vn] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 |
| for i, j in T.grid(127, 127): |
| with T.sblock("C"): |
| cvi, cvj = T.axis.remap("SS", [i, j]) |
| C[cvi, cvj] = B[cvi // 16, cvj // 16, cvi % 16, cvj % 16] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_overcomputed_producer_injective_load_reverse_inlined( |
| A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32") |
| ) -> None: |
| for i0, j0, i1, j1 in T.grid(8, 8, 16, 16): |
| with T.sblock("B"): |
| vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1]) |
| T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127) |
| C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_producer_not_cover_consumer( |
| A: T.Buffer((128, 128), "float32"), D: T.Buffer((256, 128), "float32") |
| ) -> None: |
| B = T.alloc_buffer((128, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(256, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32") |
| |
| |
| @T.prim_func |
| def elementwise_producer_is_reduction( |
| A: T.Buffer((128, 128), "float32"), D: T.Buffer((128), "float32") |
| ) -> None: |
| B = T.alloc_buffer((128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SR", [i, j]) |
| with T.init(): |
| B[vi] = T.float32(0) |
| B[vi] = B[vi] + A[vi, vj] |
| for i in T.grid(128): |
| with T.sblock("C"): |
| vi = T.axis.remap("S", [i]) |
| D[vi] = B[vi] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_predicate_producer(a: T.handle, c: T.handle) -> None: |
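| # The producer's predicate (i < 127) exactly matches the consumer's iteration domain, so |
| # reverse inlining is allowed. |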
| A = T.match_buffer(a, (128, 128)) |
| B = T.alloc_buffer((127, 128)) |
| C = T.match_buffer(c, (127, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.where(i < 127) |
| B[vi, vj] = A[vi, vj] * 2.0 |
| for i, j in T.grid(127, 128): |
| with T.sblock("C"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| C[vi, vj] = B[vi, vj] + 1.0 |
| |
| |
| @T.prim_func |
| def elementwise_predicate_producer_inlined(a: T.handle, c: T.handle) -> None: |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (127, 128)) |
| for i, j in T.grid(128, 128): |
| with T.sblock("B"): |
| T.where(i < 127) |
| vi, vj = T.axis.remap("SS", [i, j]) |
| T.reads(A[vi, vj]) |
| T.writes(C[vi, vj]) |
| C[vi, vj] = A[vi, vj] * T.float32(2) + T.float32(1) |
| |
| |
| # fmt: off |
| @tvm.script.ir_module |
| class Conv2dInt8_TensorCore_with_predicate_before: |
| @T.prim_func |
| def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer(256, "int32"), p5: T.Buffer(256, "int32"), p6: T.Buffer(256, "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): |
| # function attr dict |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| # body |
| with T.sblock("root"): |
| T.reads() |
| T.writes() |
| T.sblock_attr({"meta_schedule.unroll_explicit":1024}) |
| compute_3 = T.alloc_buffer([16, 56, 56, 256], dtype="int32") |
| conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") |
| conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") |
| pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared") |
| p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared") |
| pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a") |
| p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b") |
| for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"): |
| for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"): |
| for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): |
| for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): |
| for ax0_ax1_fused in T.serial(1024): |
| with T.sblock("pad_temp_reindex_shared"): |
| v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) |
| v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) |
| T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) |
| T.writes(pad_temp_reindex_shared[v0, v1]) |
| T.sblock_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4}) |
| pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] |
| for ax0_ax1_ax2_ax3_fused in T.serial(2048): |
| with T.sblock("p1_reindex_shared"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(1, 0) |
| v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) |
| v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) |
| T.reads(p1[v2, v0, v1, v3]) |
| T.writes(p1_reindex_shared[v0, v1, v2, v3]) |
| T.sblock_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3}) |
| p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] |
| for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): |
| for ax0_0_1, ax1_0_1 in T.grid(1, 1): |
| with T.sblock("pad_temp_reindex_shared_wmma.matrix_a_o"): |
| v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) |
| v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) |
| T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) |
| T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"}) |
| for ax0_1_1, ax1_1_1 in T.grid(16, 16): |
| with T.sblock("pad_temp_reindex_shared_wmma.matrix_a"): |
| v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) |
| T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] |
| for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): |
| with T.sblock("p1_reindex_shared_wmma.matrix_b_o"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(1, 0) |
| v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) |
| v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) |
| T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) |
| T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"}) |
| for ax2_1, ax3_1 in T.grid(16, 16): |
| with T.sblock("p1_reindex_shared_wmma.matrix_b"): |
| v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) |
| T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) |
| T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) |
| p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] |
| for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): |
| with T.sblock("conv2d_nhwc_o"): |
| v0 = T.axis.reduce(1, 0) |
| v1 = T.axis.reduce(1, 0) |
| v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) |
| v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) |
| v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) |
| T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) |
| T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) |
| with T.init(): |
| for ax2_1, ax3_1 in T.grid(16, 16): |
| with T.sblock("conv2d_nhwc_init"): |
| v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) |
| T.reads() |
| T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) |
| conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 |
| for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): |
| with T.sblock("conv2d_nhwc"): |
| v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) |
| T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) |
| T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) |
| T.sblock_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) |
| conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "int32") * T.cast(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i], "int32") |
| for ax0_0, ax1_0 in T.grid(1, 2): |
| with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): |
| v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) |
| v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) |
| T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) |
| T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"}) |
| for ax0_1, ax1_1 in T.grid(16, 16): |
| with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): |
| v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) |
| T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] |
| for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): |
| with T.sblock("conv2d_nhwc_reindex_shared"): |
| T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) |
| v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) |
| v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) |
| T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0]) |
| T.writes(compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) |
| compute_3[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, False, True, dtype="int32"), 255), 0) - p8[0], 1457846997, 31, 0, dtype="int32") |
| for i0_12, i1_12, i2_12, i3_12 in T.grid(16, 56, 56, 256): |
| with T.sblock("compute_4"): |
| i0_13, i1_13, i2_13, i3_13 = T.axis.remap("SSSS", [i0_12, i1_12, i2_12, i3_12]) |
| T.reads(compute_3[i0_13, i1_13, i2_13, i3_13], p9[i0_13, i1_13, i2_13, i3_13]) |
| T.writes(compute[i0_13, i1_13, i2_13, i3_13]) |
| compute[i0_13, i1_13, i2_13, i3_13] = T.max(T.min(compute_3[i0_13, i1_13, i2_13, i3_13] + T.q_multiply_shift(p9[i0_13, i1_13, i2_13, i3_13], 2101000910, 31, 0, dtype="int32"), 255), 0) |
| |
| @tvm.script.ir_module |
| class Conv2dInt8_TensorCore_with_predicate_after: |
| @T.prim_func |
| def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| with T.sblock("root"): |
| T.reads() |
| T.writes() |
| T.sblock_attr({"meta_schedule.unroll_explicit": 1024}) |
| conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") |
| conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") |
| pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") |
| p1_reindex_shared = T.alloc_buffer((1, 1, 256, 64), "int8", scope="shared") |
| pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer((50176, 64), "int8", scope="wmma.matrix_a") |
| p1_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 256, 64), "int8", scope="wmma.matrix_b") |
| for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"): |
| for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"): |
| for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): |
| for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): |
| for ax0_ax1_fused in range(1024): |
| with T.sblock("pad_temp_reindex_shared"): |
| v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) |
| v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) |
| T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) |
| T.writes(pad_temp_reindex_shared[v0, v1]) |
| T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) |
| pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] |
| for ax0_ax1_ax2_ax3_fused in range(2048): |
| with T.sblock("p1_reindex_shared"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(1, 0) |
| v2 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + ax0_ax1_ax2_ax3_fused // 32) |
| v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) |
| T.reads(p1[v2, v0, v1, v3]) |
| T.writes(p1_reindex_shared[v0, v1, v2, v3]) |
| T.sblock_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) |
| p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] |
| for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): |
| for ax0_0_1, ax1_0_1 in T.grid(1, 1): |
| with T.sblock("pad_temp_reindex_shared_wmma.matrix_a_o"): |
| v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) |
| v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) |
| T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) |
| T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_a_shared"}) |
| for ax0_1_1, ax1_1_1 in T.grid(16, 16): |
| with T.sblock("pad_temp_reindex_shared_wmma.matrix_a"): |
| v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) |
| T.reads(pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] |
| for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): |
| with T.sblock("p1_reindex_shared_wmma.matrix_b_o"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(1, 0) |
| v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) |
| v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1) |
| T.reads(p1_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) |
| for ax2_1, ax3_1 in T.grid(16, 16): |
| with T.sblock("p1_reindex_shared_wmma.matrix_b"): |
| v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) |
| T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) |
| T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) |
| p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] |
| for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): |
| with T.sblock("conv2d_nhwc_o"): |
| v0 = T.axis.reduce(1, 0) |
| v1 = T.axis.reduce(1, 0) |
| v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) |
| v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) |
| v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) |
| T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) |
| T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) |
| with T.init(): |
| for ax2_1, ax3_1 in T.grid(16, 16): |
| with T.sblock("conv2d_nhwc_init"): |
| v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) |
| T.reads() |
| T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) |
| conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = 0 |
| for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): |
| with T.sblock("conv2d_nhwc"): |
| v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) |
| T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) |
| T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) |
| T.sblock_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) |
| conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) |
| for ax0_0, ax1_0 in T.grid(1, 2): |
| with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): |
| v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2) |
| v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) |
| T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) |
| T.writes(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) |
| T.sblock_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) |
| for ax0_1, ax1_1 in T.grid(16, 16): |
| with T.sblock("conv2d_nhwc_reindex_shared_wmma.accumulator"): |
| v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) |
| T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) |
| conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] |
| for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): |
| with T.sblock("conv2d_nhwc_reindex_shared"): |
| v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0) |
| v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3)) |
| T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64) |
| T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) |
| T.writes(compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) |
| compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, T.bool(False), T.bool(True)), 255), 0) - p8[0], 1457846997, 31, 0) + T.q_multiply_shift(p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 2101000910, 31, 0), 255), 0) |
| # fmt: on |
| |
| # pylint: enable=no-member,invalid-name,unused-variable |
| |
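| # Each parametrized test below runs twice: once passing block objects to the schedule |
| # primitives and once passing block names. |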
| use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) |
| |
| |
| def test_compute_inline_elementwise(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| block_c = sch.get_sblock("C") |
| sch.compute_inline(block_b) |
| assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) |
| assert sch.get(block_c).name_hint == "C" |
| verify_trace_roundtrip(sch=sch, mod=elementwise) |
| |
| |
| def test_compute_inline_under_loop(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_under_loop, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| block_c = sch.get_sblock("C") |
| sch.compute_inline(block_b) |
| assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) |
| assert sch.get(block_c).name_hint == "C" |
| verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) |
| |
| |
| def test_compute_inline_as_dce(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_standalone, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| block_c = sch.get_sblock("C") |
| sch.compute_inline(block_b) |
| assert_structural_equal_ignore_global_symbol(elementwise_standalone_dce, sch.mod["main"]) |
| assert sch.get(block_c).name_hint == "C" |
| verify_trace_roundtrip(sch=sch, mod=elementwise_standalone) |
| |
| |
| def test_compute_inline_multi_consumer(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| block_c = sch.get_sblock("C") |
| block_d = sch.get_sblock("D") |
| sch.compute_inline(block_b) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_multi_consumer_inlined, sch.mod["main"] |
| ) |
| assert sch.get(block_c).name_hint == "C" |
| assert sch.get(block_d).name_hint == "D" |
| verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) |
| |
| |
| def test_compute_inline_fail_multi_writer(use_block_name): |
| sch = tvm.s_tir.Schedule(fail_multi_reader_writer, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block_b) |
| |
| |
| def test_reverse_compute_inline_elementwise(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise, debug_mask="all") |
| block_b = sch.get_sblock("B") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) |
| assert sch.get(block_b).name_hint == "B" |
| verify_trace_roundtrip(sch=sch, mod=elementwise) |
| |
| |
| def test_reverse_compute_inline_under_loop(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_under_loop, debug_mask="all") |
| block_b = sch.get_sblock("B") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol(elementwise_inlined, sch.mod["main"]) |
| assert sch.get(block_b).name_hint == "B" |
| verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) |
| |
| |
| def test_reverse_compute_inline_fail_as_dce(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_standalone, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(block_b) |
| |
| |
| def test_reverse_compute_inline_fail_multi_producer(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") |
| block_d = "D" if use_block_name else sch.get_sblock("D") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(block_d) |
| |
| |
| def test_reverse_compute_inline_fail_multi_reader(use_block_name): |
| sch = tvm.s_tir.Schedule(fail_multi_reader_writer, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(block_c) |
| |
| |
| def test_reverse_compute_multi_reverse_loads(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_multi_reverse_loads_inlined, sch.mod["main"] |
| ) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads) |
| |
| |
| def test_reverse_compute_inline_affine_load(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_reverse_affine_load, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_reverse_affine_load_inlined, sch.mod["main"] |
| ) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load) |
| |
| |
| def test_reverse_compute_inline_multi_affine_load(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_multi_reverse_affine_load_inlined, sch.mod["main"] |
| ) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load) |
| |
| |
| def test_reverse_compute_inline_affine_load_unit_iter(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"] |
| ) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter) |
| |
| |
| def test_reverse_compute_inline_affine_load_unit_iter_simplified(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_reverse_affine_load_unit_iter_simplified, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(block_c) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_reverse_affine_load_unit_iter_simplified_inlined, sch.mod["main"] |
| ) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter_simplified) |
| |
| |
| @pytest.mark.parametrize("reverse_order", [True, False]) |
| def test_reverse_compute_inline_affine_chain(use_block_name, reverse_order): |
| sch = tvm.s_tir.Schedule(elementwise_reverse_affine_chain, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| block_d = "D" if use_block_name else sch.get_sblock("D") |
| if reverse_order: |
| sch.reverse_compute_inline(block_d) |
| sch.reverse_compute_inline(block_c) |
| else: |
| sch.reverse_compute_inline(block_c) |
| sch.reverse_compute_inline(block_d) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_reverse_affine_chain_inlined, sch.mod["main"] |
| ) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_chain) |
| |
| |
| def test_reverse_compute_fail_non_affine_load(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(block_c) |
| |
| |
| def test_reverse_compute_fail_multi_reverse_loads(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_multi_loads, debug_mask="all") |
| block_c = "C" if use_block_name else sch.get_sblock("C") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(block_c) |
| |
| |
| def test_opaque_access_load(use_block_name): |
| sch = tvm.s_tir.Schedule(opaque_access_load, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block_b) |
| |
| |
| def test_opaque_access_store(use_block_name): |
| sch = tvm.s_tir.Schedule(opaque_access_store, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block_b) |
| |
| |
| def test_buffer_matched(use_block_name): |
| sch = tvm.s_tir.Schedule(buffer_matched, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block_b) |
| |
| |
| def test_output_block(use_block_name): |
| sch = tvm.s_tir.Schedule(matmul_relu, debug_mask="all") |
| block = sch.get_sblock("compute") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block) |
| |
| sch = tvm.s_tir.Schedule(elementwise_output, debug_mask="all") |
| block = sch.get_sblock("B") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block) |
| |
| block = sch.get_sblock("C") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(block) |
| |
| |
| def test_compute_inline_predicate(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_predicate, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| sch.compute_inline(block_b) |
| assert_structural_equal_ignore_global_symbol(elementwise_predicate_inlined, sch.mod["main"]) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) |
| |
| |
| def test_compute_inline_multi_loads(use_block_name): |
| sch = tvm.s_tir.Schedule(elementwise_multi_loads, debug_mask="all") |
| block_b = "B" if use_block_name else sch.get_sblock("B") |
| sch.compute_inline(block_b) |
| assert_structural_equal_ignore_global_symbol(elementwise_multi_loads_inlined, sch.mod["main"]) |
| verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) |
| |
| |
| def test_compute_inline_with_opaque_access(use_block_name): |
| """Test not rewrite opaque reads/writes after irrelavant compute inline""" |
| sch = tvm.s_tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all") |
| BB = "BB" if use_block_name else sch.get_sblock("BB") |
| sch.compute_inline(BB) |
| assert_structural_equal_ignore_global_symbol( |
| access_opaque_ptr_then_elemwise_inline, sch.mod["main"] |
| ) |
| |
| |
| def test_inline_block_with_init(): |
| sch = tvm.s_tir.Schedule(inline_block_with_init, debug_mask="all") |
| block = sch.get_sblock(name="tensor_rf", func_name="main") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.compute_inline(block=block) |
| |
| |
| def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): |
| """Test opaque access with tvm_access_ptr after compute inline""" |
| sch = tvm.s_tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all") |
| compute = "compute" if use_block_name else sch.get_sblock("compute") |
| sch.compute_inline(compute) |
| assert_structural_equal_ignore_global_symbol( |
| exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"] |
| ) |
| |
| |
| def test_reverse_compute_inline_overcomputed_producer(use_block_name): |
| """Test reverse compute inline overcomputed producer""" |
| sch = tvm.s_tir.Schedule(elementwise_overcomputed_producer, debug_mask="all") |
| compute = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(compute) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_overcomputed_producer_reverse_inlined, sch.mod["main"] |
| ) |
| |
| |
| def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name): |
| """Test reverse compute inline overcomputed producer where the predicate should be simplified""" |
| sch = tvm.s_tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all") |
| compute = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(compute) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_overcomputed_producer_simplify_predicate_reverse_inlined, sch.mod["main"] |
| ) |
| |
| |
| def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_name): |
| """Test reverse compute inline overcomputed producer with injective buffer load""" |
| sch = tvm.s_tir.Schedule(elementwise_overcomputed_producer_injective_load, debug_mask="all") |
| compute = "C" if use_block_name else sch.get_sblock("C") |
| sch.reverse_compute_inline(compute) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_overcomputed_producer_injective_load_reverse_inlined, sch.mod["main"] |
| ) |
| |
| |
| def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name): |
| """Test reverse compute inline failure when the inlined block iter domains are not covered by |
| its producer |
| """ |
| sch = tvm.s_tir.Schedule(elementwise_producer_not_cover_consumer, debug_mask="all") |
| compute = "C" if use_block_name else sch.get_sblock("C") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(compute) |
| |
| |
| def test_reverse_compute_inline_producer_predicate_allowed(): |
| """Test a case where reverse compute inline is allowed even though the producer has a |
| non-trivial predicate. |
| """ |
| |
| sch = tvm.s_tir.Schedule(elementwise_predicate_producer, debug_mask="all") |
| sch.reverse_compute_inline(sch.get_sblock("C")) |
| assert_structural_equal_ignore_global_symbol( |
| elementwise_predicate_producer_inlined, sch.mod["main"] |
| ) |
| |
| |
| def test_reverse_compute_inline_producer_predicate_disallowed(): |
| """Test reverse compute inline failure when the producer has a non-trivial predicate that cannot be |
| implied by the synthesized predicate of the new inlined block. |
| """ |
| |
| sch = tvm.s_tir.Schedule(Conv2dInt8_TensorCore_with_predicate_before, debug_mask="all") |
| sch.reverse_compute_inline(sch.get_sblock("compute_4")) |
| assert_structural_equal_ignore_global_symbol( |
| Conv2dInt8_TensorCore_with_predicate_after["main"], sch.mod["main"] |
| ) |
| |
| |
| def test_reverse_compute_inline_producer_is_reduction(): |
| """Test reverse comput inline when producer is reduction""" |
| sch = tvm.s_tir.Schedule(elementwise_producer_is_reduction, debug_mask="all") |
| with pytest.raises(tvm.s_tir.ScheduleError): |
| sch.reverse_compute_inline(sch.get_sblock("C")) |
| |
| |
| def test_compute_inline_softmax(): |
| # fmt: off |
| @T.prim_func |
| def before(p_lv44: T.handle, p_output0: T.handle): |
| T.func_attr({"tir.noalias": True}) |
| n, m = T.int64(), T.int64() |
| lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) |
| var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") |
| T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) |
| T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) |
| T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) |
| var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) |
| for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_maxelem"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(lv44[v_i0, v_i1, v_i2, v_k]) |
| T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) |
| with T.init(): |
| T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) |
| T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) |
| for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_exp"): |
| v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) |
| T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) |
| T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) |
| for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_expsum"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) |
| T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) |
| with T.init(): |
| T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) |
| T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] |
| for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_norm"): |
| v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) |
| T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| T.sblock_attr({"axis": 3}) |
| var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] |
| for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("compute"): |
| v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| |
| @T.prim_func |
| def after(p_lv44: T.handle, p_output0: T.handle): |
| T.func_attr({"tir.noalias": True}) |
| n, m = T.int64(), T.int64() |
| lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) |
| var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") |
| # with T.sblock("root"): |
| T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) |
| T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) |
| var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) |
| for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_maxelem"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(lv44[v_i0, v_i1, v_i2, v_k]) |
| T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) |
| with T.init(): |
| T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) |
| T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) |
| for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_expsum"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(lv44[v_i0, v_i1, v_i2, v_k], T_softmax_maxelem[v_i0, v_i1, v_i2]) |
| T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) |
| with T.init(): |
| T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) |
| T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T.exp(lv44[v_i0, v_i1, v_i2, v_k] - T_softmax_maxelem[v_i0, v_i1, v_i2]) |
| for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("T_softmax_norm"): |
| v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1, v_i2]) |
| T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| T.sblock_attr({"axis": 3}) |
| var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] |
| for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): |
| with T.sblock("compute"): |
| v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) |
| # fmt: on |
| |
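| # Inlining "T_softmax_exp" substitutes exp(lv44 - maxelem) directly into its two consumers |
| # ("T_softmax_expsum" and "T_softmax_norm"), so the T_softmax_exp buffer disappears in `after`. |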
| sch = tvm.s_tir.Schedule(before) |
| sch.compute_inline(sch.get_sblock("T_softmax_exp")) |
| assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) |
| |
| |
| def test_reverse_compute_inline_layer_norm(): |
| # fmt: off |
| @T.prim_func |
| def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| n = T.int64() |
| lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) |
| var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") |
| A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), scope="shared") |
| A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), scope="shared") |
| var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) |
| for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): |
| for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(10)): |
| for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): |
| with T.sblock("A_red_temp"): |
| v_ax0 = T.axis.spatial(T.int64(1), ax0) |
| v_ax1 = T.axis.spatial(n, ax0_ax1_fused + ax1) |
| v_k2 = T.axis.reduce(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) |
| T.reads(lv6[v_ax0, v_ax1, v_k2]) |
| T.writes(A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1]) |
| with T.init(): |
| A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0) |
| A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0) |
| v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] |
| v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] |
| A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0 |
| A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 |
| for ax2_0 in range(T.int64(10)): |
| for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): |
| with T.sblock("T_layer_norm"): |
| v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) |
| v_ax1 = T.axis.spatial(n, ax0_ax1_fused) |
| v_ax2 = T.axis.spatial(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) |
| T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) |
| T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) |
| var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] |
| for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): |
| with T.sblock("compute"): |
| v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) |
| T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) |
| T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) |
| var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) |
| |
| @T.prim_func |
| def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| n = T.int64() |
| lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) |
| var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") |
| # with T.sblock("root"): |
| A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), scope="shared") |
| A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), scope="shared") |
| for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): |
| for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(10)): |
| for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): |
| with T.sblock("A_red_temp"): |
| v_ax0 = T.axis.spatial(T.int64(1), ax0) |
| v_ax1 = T.axis.spatial(n, ax0_ax1_fused + ax1) |
| v_k2 = T.axis.reduce(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) |
| T.reads(lv6[v_ax0, v_ax1, v_k2]) |
| T.writes(A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1]) |
| with T.init(): |
| A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0) |
| A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0) |
| v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] |
| v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] |
| A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0 |
| A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 |
| for ax2_0 in range(T.int64(10)): |
| for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): |
| with T.sblock("T_layer_norm"): |
| v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) |
| v_ax1 = T.axis.spatial(n, ax0_ax1_fused) |
| v_ax2 = T.axis.spatial(T.int64(2560), ax2_0 * T.int64(256) + ax2_1) |
| T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0_shared[v_ax0, v_ax1], A_red_temp_v1_shared[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) |
| T.writes(var_compute_intermediate[v_ax0, v_ax1, v_ax2]) |
| var_compute_intermediate[v_ax0, v_ax1, v_ax2] = T.Cast("float16", (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2]) |
| # fmt: on |
| |
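| # Reverse-inlining the "compute" cast block folds the float16 cast into "T_layer_norm", which |
| # then writes the output directly, so var_T_layer_norm_intermediate is no longer allocated in `after`. |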
| sch = tvm.s_tir.Schedule(before) |
| sch.reverse_compute_inline(sch.get_sblock("compute")) |
| assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) |
| |
| |
| def test_reverse_compute_inline_slicing_then_cachewrite(): |
| @T.prim_func |
| def before( |
| x: T.Buffer((1, 16, 7, 7), "float32"), |
| T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"), |
| ): |
| T_add = T.alloc_buffer((1, 16, 7, 7)) |
| for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7): |
| with T.sblock("T_add"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T_add[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1) |
| for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7): |
| with T.sblock("T_strided_slice_with_axes"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = T_add[ |
| v_ax0, v_ax1, v_ax2, v_ax3 |
| ] |
| |
| @T.prim_func |
| def after( |
| x: T.Buffer((1, 16, 7, 7), "float32"), |
| T_strided_slice_with_axes: T.Buffer((1, 12, 7, 7), "float32"), |
| ): |
| T_strided_slice_with_axes_global = T.alloc_buffer((1, 12, 7, 7)) |
| for ax0, ax1, ax2, ax3 in T.grid(1, 16, 7, 7): |
| with T.sblock("T_add"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T.where(ax1 < 12) |
| T_strided_slice_with_axes_global[v_ax0, v_ax1, v_ax2, v_ax3] = x[ |
| v_ax0, v_ax1, v_ax2, v_ax3 |
| ] + T.float32(1) |
| for ax0, ax1, ax2, ax3 in T.grid(1, 12, 7, 7): |
| with T.sblock("T_strided_slice_with_axes_global"): |
| v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T_strided_slice_with_axes[v0, v1, v2, v3] = T_strided_slice_with_axes_global[ |
| v0, v1, v2, v3 |
| ] |
| |
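| # Reverse-inlining the slicing block adds a T.where(ax1 < 12) predicate to "T_add"; the |
| # following cache_write then stages the result through T_strided_slice_with_axes_global. |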
| sch = tvm.s_tir.Schedule(before) |
| sch.reverse_compute_inline(sch.get_sblock("T_strided_slice_with_axes")) |
| sch.cache_write(sch.get_sblock("T_add"), 0, "global") |
| assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) |
| |
| |
| def test_inline_with_reduction(): |
| @T.prim_func |
| def before( |
| T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"), |
| T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"), |
| T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"), |
| ): |
| T_batch_matmul_NN = T.alloc_buffer((T.int64(6), T.int64(1), T.int64(64))) |
| for ax0, ax1 in T.grid(T.int64(6), T.int64(64)): |
| with T.sblock("bmm"): |
| v0, v1 = T.axis.remap("SS", [ax0, ax1]) |
| T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1]) |
| T.writes(T_batch_matmul_NN[v0, T.int64(0), v1]) |
| with T.init(): |
| T_batch_matmul_NN[v0, T.int64(0), v1] = T.float32(0.0) |
| T_batch_matmul_NN[v0, T.int64(0), v1] = ( |
| T_batch_matmul_NN[v0, T.int64(0), v1] |
| + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1] |
| ) |
| for ax0, ax1 in T.grid(T.int64(6), T.int64(64)): |
| with T.sblock("transpose"): |
| v0, v1 = T.axis.remap("SS", [ax0, ax1]) |
| T.reads(T_batch_matmul_NN[v0, T.int64(0), v1]) |
| T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1]) |
| T_transpose[T.int64(0), T.int64(0), v0, v1] = T_batch_matmul_NN[v0, T.int64(0), v1] |
| |
| @T.prim_func |
| def after( |
| T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"), |
| T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"), |
| T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"), |
| ): |
| for ax0, ax1 in T.grid(T.int64(6), T.int64(64)): |
| with T.sblock("bmm"): |
| v0, v1 = T.axis.remap("SS", [ax0, ax1]) |
| T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1]) |
| T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1]) |
| with T.init(): |
| T_transpose[T.int64(0), T.int64(0), v0, v1] = T.float32(0.0) |
| T_transpose[T.int64(0), T.int64(0), v0, v1] = ( |
| T_transpose[T.int64(0), T.int64(0), v0, v1] |
| + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1] |
| ) |
| |
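| # Reverse-inlining the "transpose" consumer lets the reduction block write its result |
| # directly into T_transpose, eliminating the T_batch_matmul_NN intermediate buffer. |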
| sch = tvm.s_tir.Schedule(before) |
| sch.reverse_compute_inline(sch.get_sblock("transpose")) |
| assert_structural_equal_ignore_global_symbol(after, sch.mod["main"]) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |