|  | # Licensed to the Apache Software Foundation (ASF) under one | 
|  | # or more contributor license agreements.  See the NOTICE file | 
|  | # distributed with this work for additional information | 
|  | # regarding copyright ownership.  The ASF licenses this file | 
|  | # to you under the Apache License, Version 2.0 (the | 
|  | # "License"); you may not use this file except in compliance | 
|  | # with the License.  You may obtain a copy of the License at | 
|  | # | 
|  | #   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | # | 
|  | # Unless required by applicable law or agreed to in writing, | 
|  | # software distributed under the License is distributed on an | 
|  | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | # KIND, either express or implied.  See the License for the | 
|  | # specific language governing permissions and limitations | 
|  | # under the License. | 
|  | # pylint: disable=missing-docstring,line-too-long,invalid-name,too-few-public-methods,too-many-locals | 
|  |  | 
|  | import tvm.testing | 
|  | from tvm import dlight as dl | 
|  | from tvm.ir import assert_structural_equal | 
|  | from tvm.script import ir as I | 
|  | from tvm.script import tir as T | 
|  | from tvm.target import Target | 
|  |  | 
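# Each test below defines a `Before` (or `Module`) IRModule and the expected
# scheduled module, applies dlight's GPU `Reduction` rule via
# `dl.ApplyDefaultSchedule` under an RTX 3090 Ti target, and checks the result
# with `assert_structural_equal`.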
|  |  | 
|  | def test_decode_gemv_1(): | 
|  | # NK layout + K as decode dim | 
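# W packs eight 4-bit weights per uint32 along K (4096 / 8 = 512 columns) and
# S holds one float16 scale per group of 32 weights (4096 / 32 = 128 columns);
# the decode block extracts each nibble, subtracts the zero point 7, and scales it.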
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | B = T.alloc_buffer((4096, 4096), "float16") | 
|  | for i, j in T.grid(4096, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) | 
|  | T.writes(B[v_i, v_j]) | 
|  | B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | C[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, v_k] * B[v_i2, v_k] | 
|  |  | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T.handle): | 
|  | T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) | 
|  | W = T.match_buffer(W_handle, (4096, 512), "uint32") | 
|  | S = T.match_buffer(S_handle, (4096, 128), "float16") | 
|  | V = T.match_buffer(V_handle, (1, 1, 4096), "float16") | 
|  | C = T.match_buffer(C_handle, (1, 1, 4096), "float16") | 
|  | with T.block("root"): | 
|  | T.reads() | 
|  | T.writes() | 
|  | C_rf_local = T.alloc_buffer((512, 1, 1, 4096), "float16", scope="local") | 
|  | for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"): | 
|  | for ax1_0_fused_1 in T.thread_binding(512, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | T.reads() | 
|  | T.writes(C_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | C_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) | 
|  | for ax1_0_fused_0 in range(1): | 
|  | for ax1_1 in range(8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | vax1_0_fused_0 = T.axis.reduce(1, ax1_0_fused_0) | 
|  | vax1_1 = T.axis.reduce(8, ax1_1) | 
|  | T.reads(C_rf_local[vax1_0_fused_1, 0, 0, v0], V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1], W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) | 
|  | T.writes(C_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | C_rf_local[vax1_0_fused_1, 0, 0, v0] = C_rf_local[vax1_0_fused_1, 0, 0, v0] + V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) | 
|  | for ax1_fused in range(1): | 
|  | for ax0 in T.thread_binding(512, thread="threadIdx.x"): | 
|  | with T.block("matmul"): | 
|  | vax1_0_fused_1 = T.axis.reduce(512, ax0) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | T.reads(C_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | T.writes(C[0, 0, v0]) | 
|  | with T.init(): | 
|  | C[0, 0, v0] = T.float16(0) | 
|  | C[0, 0, v0] = C[0, 0, v0] + C_rf_local[vax1_0_fused_1, 0, 0, v0] | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_decode_gemv_2(): | 
|  | # KN layout + K as decode dim | 
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | B = T.alloc_buffer((4096, 4096), "float16") | 
|  | for i, j in T.grid(4096, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j]) | 
|  | T.writes(B[v_i, v_j]) | 
|  | B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | C[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, v_k] * B[v_k, v_i2] | 
|  |  | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") | 
|  | for i2_i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"): | 
|  | for i2_i0_i1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for k_0_fused_1 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("matmul_rf_init"): | 
|  | vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1) | 
|  | v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + i2_i0_i1_fused_1) | 
|  | C_rf_local[vk_0_fused_1, 0, 0, v_i2] = T.float16(0) | 
|  | for k_0_fused_0, k_1 in T.grid(32, 8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1) | 
|  | v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + i2_i0_i1_fused_1) | 
|  | vk_0_fused_0, vk_1 = T.axis.remap("RR", [k_0_fused_0, k_1]) | 
|  | C_rf_local[vk_0_fused_1, 0, 0, v_i2] = C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 32, v_i2]) | 
|  | for ax1_ax2_ax3_fused in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax0_fused in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vk_0_fused_1 = T.axis.reduce(16, ax0_fused) | 
|  | v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 16 + ax1_ax2_ax3_fused) | 
|  | with T.init(): | 
|  | C[0, 0, v_i2] = T.float16(0) | 
|  | C[0, 0, v_i2] = C[0, 0, v_i2] + C_rf_local[vk_0_fused_1, 0, 0, v_i2] | 
|  |  | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_decode_gemv_3(): | 
|  | # NK layout + N as decode dim | 
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | B = T.alloc_buffer((4096, 4096), "float16") | 
|  | for i, j in T.grid(4096, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(W[v_i // 8, v_j], S[v_i // 32, v_j]) | 
|  | T.writes(B[v_i, v_j]) | 
|  | B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i // 32, v_j] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | C[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, v_k] * B[v_i2, v_k] | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T.handle): | 
|  | T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) | 
|  | W = T.match_buffer(W_handle, (512, 4096), "uint32") | 
|  | S = T.match_buffer(S_handle, (128, 4096), "float16") | 
|  | V = T.match_buffer(V_handle, (1, 1, 4096), "float16") | 
|  | C = T.match_buffer(C_handle, (1, 1, 4096), "float16") | 
|  | with T.block("root"): | 
|  | T.reads() | 
|  | T.writes() | 
|  | C_rf_local = T.alloc_buffer((1024, 1, 1, 4096), "float16", scope="local") | 
|  | for ax0_0_fused in T.thread_binding(512, thread="blockIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | for ax0_1_init in range(8): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax0_1_init) | 
|  | T.reads() | 
|  | T.writes(C_rf_local[vax1_fused_1, 0, 0, v0]) | 
|  | C_rf_local[vax1_fused_1, 0, 0, v0] = T.float16(0) | 
|  | for ax1_fused_0 in range(4): | 
|  | for ax0_1 in range(8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax0_1) | 
|  | vax1_fused_0 = T.axis.reduce(4, ax1_fused_0) | 
|  | T.reads(C_rf_local[vax1_fused_1, 0, 0, v0], V[0, 0, vax1_fused_0 * 1024 + vax1_fused_1], W[v0 // 8, vax1_fused_0 * 1024 + vax1_fused_1], S[v0 // 32, vax1_fused_0 * 1024 + vax1_fused_1]) | 
|  | T.writes(C_rf_local[vax1_fused_1, 0, 0, v0]) | 
|  | C_rf_local[vax1_fused_1, 0, 0, v0] = C_rf_local[vax1_fused_1, 0, 0, v0] + V[0, 0, vax1_fused_0 * 1024 + vax1_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v0 // 8, vax1_fused_0 * 1024 + vax1_fused_1], T.Cast("uint32", v0 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0 // 32, vax1_fused_0 * 1024 + vax1_fused_1]) | 
|  | for ax1_fused_0 in range(1): | 
|  | for ax0 in T.thread_binding(1024, thread="threadIdx.x"): | 
|  | for ax1_fused_1 in range(8): | 
|  | with T.block("matmul"): | 
|  | vax1_fused_1 = T.axis.reduce(1024, ax0) | 
|  | v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax1_fused_1) | 
|  | T.reads(C_rf_local[vax1_fused_1, 0, 0, v0]) | 
|  | T.writes(C[0, 0, v0]) | 
|  | with T.init(): | 
|  | C[0, 0, v0] = T.float16(0) | 
|  | C[0, 0, v0] = C[0, 0, v0] + C_rf_local[vax1_fused_1, 0, 0, v0] | 
|  |  | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_decode_gemv_4(): | 
|  | # KN layout + N as decode dim | 
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | B = T.alloc_buffer((4096, 4096), "float16") | 
|  | for i, j in T.grid(4096, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) | 
|  | T.writes(B[v_i, v_j]) | 
|  | B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(V[v_i0, v_i1, v_k], B[v_k, v_i2]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | C[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, v_k] * B[v_k, v_i2] | 
|  |  | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") | 
|  | for i2_0_i0_i1_fused_0 in T.thread_binding(32, thread="blockIdx.x"): | 
|  | for i2_0_i0_i1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for k_fused_1 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | for i2_1_init in range(8): | 
|  | with T.block("matmul_rf_init"): | 
|  | vk_fused_1 = T.axis.spatial(16, k_fused_1) | 
|  | v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init) | 
|  | C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0) | 
|  | for k_fused_0, i2_1 in T.grid(256, 8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vk_fused_1 = T.axis.spatial(16, k_fused_1) | 
|  | v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + i2_0_i0_i1_fused_1 * 8 + i2_1) | 
|  | vk_fused_0 = T.axis.reduce(256, k_fused_0) | 
|  | C_rf_local[vk_fused_1, 0, 0, v_i2] = C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32]) | 
|  | for ax1_ax2_ax3_fused_0 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax1_ax2_ax3_fused_1 in range(8): | 
|  | for ax0_fused in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vk_fused_1 = T.axis.reduce(16, ax0_fused) | 
|  | v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 * 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) | 
|  | with T.init(): | 
|  | C[0, 0, v_i2] = T.float16(0) | 
|  | C[0, 0, v_i2] = C[0, 0, v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i2] | 
|  |  | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_decode_gemv_sigmoid(): | 
# NK layout + K as decode dim, with a sigmoid epilogue
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | B = T.alloc_buffer((4096, 4096), "float16") | 
|  | C = T.alloc_buffer((1, 1, 4096), "float16") | 
|  | for i, j in T.grid(4096, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) | 
|  | T.writes(B[v_i, v_j]) | 
|  | B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | C[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + V[v_i0, v_i1, v_k] * B[v_i2, v_k] | 
|  | for i0, i1, i2 in T.grid(1, 1, 4096): | 
|  | with T.block("sigmoid"): | 
|  | v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) | 
|  | T.reads(C[v_i0, v_i1, v_i2]) | 
|  | T.writes(D[v_i0, v_i1, v_i2]) | 
|  | D[v_i0, v_i1, v_i2] = T.sigmoid(C[v_i0, v_i1, v_i2]) | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T.handle): | 
|  | T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) | 
|  | W = T.match_buffer(W_handle, (4096, 512), "uint32") | 
|  | S = T.match_buffer(S_handle, (4096, 128), "float16") | 
|  | V = T.match_buffer(V_handle, (1, 1, 4096), "float16") | 
|  | D = T.match_buffer(D_handle, (1, 1, 4096), "float16") | 
|  | with T.block("root"): | 
|  | T.reads() | 
|  | T.writes() | 
|  | C_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") | 
|  | C_rf_local = T.alloc_buffer((512, 1, 1, 4096), "float16", scope="local") | 
|  | for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"): | 
|  | for ax1_0_fused_1 in T.thread_binding(512, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | T.reads() | 
|  | T.writes(C_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | C_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) | 
|  | for ax1_0_fused_0 in range(1): | 
|  | for ax1_1 in range(8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | vax1_0_fused_0 = T.axis.reduce(1, ax1_0_fused_0) | 
|  | vax1_1 = T.axis.reduce(8, ax1_1) | 
|  | T.reads(C_rf_local[vax1_0_fused_1, 0, 0, v0], V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1], W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) | 
|  | T.writes(C_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | C_rf_local[vax1_0_fused_1, 0, 0, v0] = C_rf_local[vax1_0_fused_1, 0, 0, v0] + V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) | 
|  | for ax1_fused in range(1): | 
|  | for ax0 in T.thread_binding(512, thread="threadIdx.x"): | 
|  | with T.block("matmul"): | 
|  | vax1_0_fused_1 = T.axis.reduce(512, ax0) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | T.reads(C_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | T.writes(C_local[0, 0, v0]) | 
|  | with T.init(): | 
|  | C_local[0, 0, v0] = T.float16(0) | 
|  | C_local[0, 0, v0] = C_local[0, 0, v0] + C_rf_local[vax1_0_fused_1, 0, 0, v0] | 
|  | for ax0 in range(1): | 
|  | with T.block("sigmoid"): | 
|  | v0 = T.axis.spatial(4096, ax0_fused + ax0) | 
|  | T.reads(C_local[0, 0, v0]) | 
|  | T.writes(D[0, 0, v0]) | 
|  | D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0]) | 
|  |  | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_decode_gemv_1_fp32(): | 
# NK layout + K as decode dim, with float32 accumulation and a cast-to-float16 epilogue
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | B = T.alloc_buffer((4096, 4096), "float16") | 
|  | C_fp32 = T.alloc_buffer((1, 1, 4096), "float32") | 
|  | for i, j in T.grid(4096, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(W[v_i, v_j // 8], S[v_i, v_j // 32]) | 
|  | T.writes(B[v_i, v_j]) | 
|  | B[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i, v_j // 32] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 4096): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(V[v_i0, v_i1, v_k], B[v_i2, v_k]) | 
|  | T.writes(C_fp32[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
C_fp32[v_i0, v_i1, v_i2] = T.float32(0)
|  | C_fp32[v_i0, v_i1, v_i2] = C_fp32[v_i0, v_i1, v_i2] + T.Cast("float32", V[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i2, v_k]) | 
|  | for i0, i1, i2 in T.grid(1, 1, 4096): | 
|  | with T.block("cast"): | 
|  | v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) | 
|  | T.reads(C_fp32[v_i0, v_i1, v_i2]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | C[v_i0, v_i1, v_i2] = T.Cast("float16", C_fp32[v_i0, v_i1, v_i2]) | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T.handle): | 
|  | T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) | 
|  | W = T.match_buffer(W_handle, (4096, 512), "uint32") | 
|  | S = T.match_buffer(S_handle, (4096, 128), "float16") | 
|  | V = T.match_buffer(V_handle, (1, 1, 4096), "float16") | 
|  | C = T.match_buffer(C_handle, (1, 1, 4096), "float16") | 
|  | with T.block("root"): | 
|  | T.reads() | 
|  | T.writes() | 
|  | C_fp32_local = T.alloc_buffer((1, 1, 4096), scope="local") | 
|  | C_fp32_rf_local = T.alloc_buffer((512, 1, 1, 4096), scope="local") | 
|  | for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"): | 
|  | for ax1_0_fused_1 in T.thread_binding(512, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | T.reads() | 
|  | T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float32(0) | 
|  | for ax1_0_fused_0 in range(1): | 
|  | for ax1_1 in range(8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_0_fused_1 = T.axis.spatial(512, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | vax1_0_fused_0 = T.axis.reduce(1, ax1_0_fused_0) | 
|  | vax1_1 = T.axis.reduce(8, ax1_1) | 
|  | T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0], V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1], W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) | 
|  | T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32", (T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, (vax1_0_fused_0 * 4096 + vax1_0_fused_1 * 8 + vax1_1) // 32]) | 
|  | for ax1_fused in range(1): | 
|  | for ax0 in T.thread_binding(512, thread="threadIdx.x"): | 
|  | with T.block("matmul"): | 
|  | vax1_0_fused_1 = T.axis.reduce(512, ax0) | 
|  | v0 = T.axis.spatial(4096, ax0_fused) | 
|  | T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | T.writes(C_fp32_local[0, 0, v0]) | 
|  | with T.init(): | 
|  | C_fp32_local[0, 0, v0] = T.float32(0) | 
|  | C_fp32_local[0, 0, v0] = C_fp32_local[0, 0, v0] + C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] | 
|  | for ax0 in range(1): | 
|  | with T.block("cast"): | 
|  | v0 = T.axis.spatial(4096, ax0_fused + ax0) | 
|  | T.reads(C_fp32_local[0, 0, v0]) | 
|  | T.writes(C[0, 0, v0]) | 
|  | C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0]) | 
|  |  | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_reduction_no_spatial(): | 
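# Sum-of-squares reduction feeding an rms_norm epilogue; the reduction block
# writes a scalar (no spatial axes).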
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func | 
|  | def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), rms_norm: T.Buffer((1, 4096), "float16")): | 
|  | T.func_attr({"global_symbol": "main", "tir.noalias": True}) | 
|  | Ared_temp = T.alloc_buffer((1, 1)) | 
|  | for ax0 in range(4096): | 
|  | with T.block("Ared_temp"): | 
|  | v0 = T.axis.reduce(4096, ax0) | 
|  | with T.init(): | 
|  | Ared_temp[0, 0] = T.float32(0) | 
|  | Ared_temp[0, 0] = Ared_temp[0, 0] + T.Cast("float32", A[0, 0, v0]) * T.Cast("float32", A[0, 0, v0]) | 
|  | for ax0 in range(4096): | 
|  | with T.block("rms_norm"): | 
|  | v0 = T.axis.spatial(4096, ax0) | 
|  | rms_norm[0, v0] = T.Cast("float16", T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp[0, 0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) | 
|  |  | 
|  | @I.ir_module | 
|  | class After: | 
|  | @T.prim_func | 
|  | def main(A_handle: T.handle, B_handle: T.handle, rms_norm_handle: T.handle): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | A = T.match_buffer(A_handle, (1, 1, 4096), "float16") | 
|  | B = T.match_buffer(B_handle, (4096,), "float16") | 
|  | rms_norm = T.match_buffer(rms_norm_handle, (1, 4096), "float16") | 
|  | with T.block("root"): | 
|  | T.reads() | 
|  | T.writes() | 
|  | Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared") | 
|  | Ared_temp_rf_local = T.alloc_buffer((1024, 1, 1), scope="local") | 
|  | for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | with T.block("Ared_temp_rf_init"): | 
|  | vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | T.reads() | 
|  | T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0]) | 
|  | Ared_temp_rf_local[vax1_fused_1, 0, 0] = T.float32(0) | 
|  | for ax1_fused_0 in range(4): | 
|  | for u in range(1): | 
|  | with T.block("Ared_temp_rf_update"): | 
|  | vax1_fused_1 = T.axis.spatial(1024, ax1_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | vax1_fused_0 = T.axis.reduce(4, ax1_fused_0) | 
|  | T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0], A[0, 0, vax1_fused_0 * 1024 + vax1_fused_1]) | 
|  | T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0]) | 
|  | Ared_temp_rf_local[vax1_fused_1, 0, 0] = Ared_temp_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", A[0, 0, vax1_fused_0 * 1024 + vax1_fused_1]) * T.Cast("float32", A[0, 0, vax1_fused_0 * 1024 + vax1_fused_1]) | 
|  | for ax1_fused in range(T.int64(1)): | 
|  | for ax0 in T.thread_binding(1024, thread="threadIdx.x"): | 
|  | with T.block("Ared_temp"): | 
|  | vax1_fused_1 = T.axis.reduce(1024, ax0) | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0]) | 
|  | T.writes(Ared_temp_shared[0, 0]) | 
|  | with T.init(): | 
|  | Ared_temp_shared[0, 0] = T.float32(0) | 
|  | Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] + Ared_temp_rf_local[vax1_fused_1, 0, 0] | 
|  | for ax0_fused_0 in range(4): | 
|  | for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): | 
|  | with T.block("rms_norm"): | 
|  | v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1) | 
|  | T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0]) | 
|  | T.writes(rms_norm[0, v0]) | 
|  | rms_norm[0, v0] = T.Cast("float16", T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_shared[0, 0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) | 
|  | # fmt: on | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, After) | 
|  |  | 
|  |  | 
|  | def test_spatial_inner_no_broadcasting(): | 
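# Decode + GEMV with an elementwise add epilogue whose spatial extent matches
# the matmul output (no broadcasting of the reduced value).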
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Module: | 
|  | @T.prim_func | 
|  | def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") | 
|  | var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") | 
|  | for i, j in T.grid(11008, 4096): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) | 
|  | T.writes(p_output0_intermediate_1[v_i, v_j]) | 
|  | p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i // 32, v_j] | 
|  | for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) | 
|  | T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] | 
|  | for ax0, ax1, ax2 in T.grid(1, 1, 4096): | 
|  | with T.block("T_add"): | 
|  | v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) | 
|  | T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) | 
|  | T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) | 
|  | p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] | 
|  |  | 
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func | 
|  | def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | var_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") | 
|  | var_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") | 
|  | for ax0_fused_0 in T.thread_binding(256, thread="blockIdx.x"): | 
|  | for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax1_0_fused_1 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1) | 
|  | T.reads() | 
|  | T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) | 
|  | for ax1_0_fused_0, ax1_1 in T.grid(86, 8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_0_fused_1 = T.axis.spatial(16, ax1_0_fused_1) | 
|  | v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_1) | 
|  | vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1]) | 
|  | T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv574[0, 0, vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1], lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0]) | 
|  | T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv574[0, 0, vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 8, v0], T.Cast("uint32", (vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[(vax1_0_fused_0 * 128 + vax1_0_fused_1 * 8 + vax1_1) // 32, v0]) | 
|  | for ax1_fused in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax0 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vax1_0_fused_1 = T.axis.reduce(16, ax0) | 
|  | v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax1_fused) | 
|  | T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) | 
|  | T.writes(var_matmul_intermediate_local[0, 0, v0]) | 
|  | with T.init(): | 
|  | var_matmul_intermediate_local[0, 0, v0] = T.float16(0) | 
|  | var_matmul_intermediate_local[0, 0, v0] = var_matmul_intermediate_local[0, 0, v0] + var_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] | 
|  | for ax0_fused_0_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax0_fused_1 in range(1): | 
|  | with T.block("T_add"): | 
|  | v0 = T.axis.spatial(4096, ax0_fused_0 * 16 + ax0_fused_0_1 + ax0_fused_1) | 
|  | T.reads(lv570[0, 0, v0], var_matmul_intermediate_local[0, 0, v0]) | 
|  | T.writes(p_output0_intermediate[0, 0, v0]) | 
|  | p_output0_intermediate[0, 0, v0] = lv570[0, 0, v0] + var_matmul_intermediate_local[0, 0, v0] | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_spatial_inner_broadcasting(): | 
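# Column-wise sum whose result is broadcast over a new spatial axis in the
# elementwise epilogue.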
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Module: | 
|  | @T.prim_func | 
|  | def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | temp_local = T.alloc_buffer((256,)) | 
|  | for j in T.serial(256): | 
|  | for k in T.serial(256): | 
|  | with T.block("sum"): | 
|  | vj, vk = T.axis.remap("SR", [j, k]) | 
|  | T.reads(A[vk, vj]) | 
|  | T.writes(temp_local[vj]) | 
|  | with T.init(): | 
|  | temp_local[vj] = T.float32(0) | 
|  | temp_local[vj] = temp_local[vj] + A[vk, vj] | 
|  | for i, j in T.grid(256, 256): | 
|  | with T.block("add"): | 
|  | vi, vj = T.axis.remap("SS", [i, j]) | 
|  | T.reads(temp_local[vj]) | 
|  | T.writes(B[vi, vj]) | 
|  | B[vi, vj] = A[vi, vj] + temp_local[vj] | 
|  |  | 
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func | 
|  | def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | temp_local_shared = T.alloc_buffer((256,), scope="shared") | 
|  | temp_local_rf_local = T.alloc_buffer((16, 256), scope="local") | 
|  | for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): | 
|  | for ax0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("sum_rf_init"): | 
|  | vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) | 
|  | v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1) | 
|  | T.reads() | 
|  | T.writes(temp_local_rf_local[vax1_fused_1, v0]) | 
|  | temp_local_rf_local[vax1_fused_1, v0] = T.float32(0) | 
|  | for ax1_fused_0, u in T.grid(16, 1): | 
|  | with T.block("sum_rf_update"): | 
|  | vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) | 
|  | v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax0_fused_1) | 
|  | vax1_fused_0 = T.axis.reduce(16, ax1_fused_0) | 
|  | T.reads(temp_local_rf_local[vax1_fused_1, v0], A[vax1_fused_0 * 16 + vax1_fused_1, v0]) | 
|  | T.writes(temp_local_rf_local[vax1_fused_1, v0]) | 
|  | temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[vax1_fused_0 * 16 + vax1_fused_1, v0] | 
|  | for ax1_fused in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax0 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("sum"): | 
|  | vax1_fused_1 = T.axis.reduce(16, ax0) | 
|  | v0 = T.axis.spatial(256, ax0_fused_0 * 16 + ax1_fused) | 
|  | T.reads(temp_local_rf_local[vax1_fused_1, v0]) | 
|  | T.writes(temp_local_shared[v0]) | 
|  | with T.init(): | 
|  | temp_local_shared[v0] = T.float32(0) | 
|  | temp_local_shared[v0] = temp_local_shared[v0] + temp_local_rf_local[vax1_fused_1, v0] | 
|  | for ax0_ax1_fused_0 in range(16): | 
|  | for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax0_ax1_fused_2 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("add"): | 
|  | v0 = T.axis.spatial(256, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) // 16) | 
|  | v1 = T.axis.spatial(256, ax0_fused_0 * 16 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 16) | 
|  | T.reads(temp_local_shared[v1]) | 
|  | T.writes(B[v0, v1]) | 
|  | B[v0, v1] = A[v0, v1] + temp_local_shared[v1] | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_reduction_inner_no_broadcasting(): | 
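# Row-wise sum with a scalar (+1) epilogue on the reduced output; no broadcasting.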
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Module: | 
|  | @T.prim_func | 
|  | def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | temp_local = T.alloc_buffer((256,)) | 
|  | for i in T.serial(256): | 
|  | for k in T.serial(256): | 
|  | with T.block("sum"): | 
|  | vi, vk = T.axis.remap("SR", [i, k]) | 
|  | T.reads(A[vi, vk]) | 
|  | T.writes(temp_local[vi]) | 
|  | with T.init(): | 
|  | temp_local[vi] = T.float32(0) | 
|  | temp_local[vi] = temp_local[vi] + A[vi, vk] | 
|  | for i in T.grid(256): | 
|  | with T.block("add"): | 
|  | vi = T.axis.remap("S", [i]) | 
|  | T.reads(temp_local[vi]) | 
T.writes(B[vi])
|  | B[vi] = temp_local[vi] + T.float32(1) | 
|  |  | 
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func | 
|  | def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | temp_local_local = T.alloc_buffer((256,), scope="local") | 
|  | temp_local_rf_local = T.alloc_buffer((256, 256), scope="local") | 
|  | for ax0_fused in T.thread_binding(256, thread="blockIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | with T.block("sum_rf_init"): | 
|  | vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused]) | 
|  | T.reads() | 
|  | T.writes(temp_local_rf_local[vax1_fused_1, v0]) | 
|  | temp_local_rf_local[vax1_fused_1, v0] = T.float32(0) | 
|  | for ax1_fused_0, u in T.grid(1, 1): | 
|  | with T.block("sum_rf_update"): | 
|  | vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]) | 
|  | T.reads(temp_local_rf_local[vax1_fused_1, v0], A[v0, vax1_fused_0 * 256 + vax1_fused_1]) | 
|  | T.writes(temp_local_rf_local[vax1_fused_1, v0]) | 
|  | temp_local_rf_local[vax1_fused_1, v0] = temp_local_rf_local[vax1_fused_1, v0] + A[v0, vax1_fused_0 * 256 + vax1_fused_1] | 
|  | for ax1_fused in range(1): | 
|  | for ax0 in T.thread_binding(256, thread="threadIdx.x"): | 
|  | with T.block("sum"): | 
|  | vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) | 
|  | T.reads(temp_local_rf_local[vax1_fused_1, v0]) | 
|  | T.writes(temp_local_local[v0]) | 
|  | with T.init(): | 
|  | temp_local_local[v0] = T.float32(0) | 
|  | temp_local_local[v0] = temp_local_local[v0] + temp_local_rf_local[vax1_fused_1, v0] | 
|  | for ax0 in range(1): | 
|  | with T.block("add"): | 
|  | v0 = T.axis.spatial(256, ax0_fused + ax0) | 
|  | T.reads(temp_local_local[v0]) | 
|  | T.writes(B[v0]) | 
|  | B[v0] = temp_local_local[v0] + T.float32(1) | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_reduction_inner_no_broadcasting2(): | 
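# Decode + GEMV whose float16 result is cast to float32 in the epilogue.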
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Module: | 
|  | @T.prim_func | 
|  | def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float16"), lv1: T.Buffer((1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 2560), "float32")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | p_output0_intermediate_1 = T.alloc_buffer((2560, 2560), "float16") | 
|  | var_matmul_intermediate = T.alloc_buffer((1, 2560), "float16") | 
|  | for i, j in T.grid(2560, 2560): | 
|  | with T.block("decode"): | 
|  | v_i, v_j = T.axis.remap("SS", [i, j]) | 
|  | T.reads(lv9[v_i, v_j // 8], lv10[v_i, v_j // 32]) | 
|  | T.writes(p_output0_intermediate_1[v_i, v_j]) | 
|  | p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv10[v_i, v_j // 32] | 
|  | for i0, i1, k in T.grid(1, 2560, 2560): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) | 
|  | T.reads(lv1[v_i0, v_k], p_output0_intermediate_1[v_k, v_i1]) | 
|  | T.writes(var_matmul_intermediate[v_i0, v_i1]) | 
|  | with T.init(): | 
|  | var_matmul_intermediate[v_i0, v_i1] = T.float16(0) | 
|  | var_matmul_intermediate[v_i0, v_i1] = var_matmul_intermediate[v_i0, v_i1] + lv1[v_i0, v_k] * p_output0_intermediate_1[v_k, v_i1] | 
|  | for i0, i1 in T.grid(1, 2560): | 
|  | with T.block("compute"): | 
|  | v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) | 
|  | T.reads(var_matmul_intermediate[v_i0, v_i1]) | 
|  | T.writes(p_output0_intermediate[v_i0, v_i1]) | 
|  | p_output0_intermediate[v_i0, v_i1] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1]) | 
|  |  | 
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func | 
|  | def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float16"), lv1: T.Buffer((1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 2560), "float32")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | var_matmul_intermediate_local = T.alloc_buffer((1, 2560), "float16", scope="local") | 
|  | var_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 2560), "float16", scope="local") | 
|  | for ax0_0_fused_0 in T.thread_binding(20, thread="blockIdx.x"): | 
|  | for ax0_0_fused_1 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | for ax0_1_init in range(8): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) | 
|  | v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax0_0_fused_1 * 8 + ax0_1_init) | 
|  | T.reads() | 
|  | T.writes(var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0]) | 
|  | var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0] = T.float16(0) | 
|  | for ax1_fused_0, ax0_1 in T.grid(160, 8): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_fused_1 = T.axis.spatial(16, ax1_fused_1) | 
|  | v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax0_0_fused_1 * 8 + ax0_1) | 
|  | vax1_fused_0 = T.axis.reduce(160, ax1_fused_0) | 
|  | T.reads(var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0], lv1[0, vax1_fused_0 * 16 + vax1_fused_1], lv9[vax1_fused_0 * 16 + vax1_fused_1, v0 // 8], lv10[vax1_fused_0 * 16 + vax1_fused_1, v0 // 32]) | 
|  | T.writes(var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0]) | 
|  | var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0] = var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0] + lv1[0, vax1_fused_0 * 16 + vax1_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv9[vax1_fused_0 * 16 + vax1_fused_1, v0 // 8], T.Cast("uint32", v0 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv10[vax1_fused_0 * 16 + vax1_fused_1, v0 // 32]) | 
|  | for ax1_fused_0 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax1_fused_1 in range(8): | 
|  | for ax0 in T.thread_binding(16, thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vax1_fused_1 = T.axis.reduce(16, ax0) | 
|  | v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax1_fused_0 * 8 + ax1_fused_1) | 
|  | T.reads(var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0]) | 
|  | T.writes(var_matmul_intermediate_local[0, v0]) | 
|  | with T.init(): | 
|  | var_matmul_intermediate_local[0, v0] = T.float16(0) | 
|  | var_matmul_intermediate_local[0, v0] = var_matmul_intermediate_local[0, v0] + var_matmul_intermediate_rf_local[vax1_fused_1, 0, v0] | 
|  | for ax0_fused_0 in T.thread_binding(16, thread="threadIdx.x"): | 
|  | for ax0_fused_1 in range(8): | 
|  | with T.block("compute"): | 
|  | v0 = T.axis.spatial(2560, ax0_0_fused_0 * 128 + ax0_fused_0 * 8 + ax0_fused_1) | 
|  | T.reads(var_matmul_intermediate_local[0, v0]) | 
|  | T.writes(p_output0_intermediate[0, v0]) | 
|  | p_output0_intermediate[0, v0] = T.Cast("float32", var_matmul_intermediate_local[0, v0]) | 
|  | # fmt: on | 
|  |  | 
|  | with Target("nvidia/geforce-rtx-3090-ti"): | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_reduction_inner_spatial_choose_perfect_factor(): | 
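# GEMV with a dynamic reduction length n; the spatial extent 100 is tiled with
# threadIdx.x = 10, a factor that divides it exactly.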
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Module: | 
|  | @T.prim_func | 
|  | def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | n = T.int64() | 
|  | A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") | 
|  | B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16") | 
|  | # with T.block("root"): | 
|  | for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(100), n): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) | 
|  | T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) | 
|  | T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) | 
|  | with T.init(): | 
|  | matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) | 
|  | matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] | 
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func | 
|  | def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | n = T.int64() | 
|  | A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") | 
|  | B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16") | 
|  | # with T.block("root"): | 
|  | matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16", scope="local") | 
|  | for ax0_ax1_fused_0 in T.thread_binding(T.int64(320), thread="blockIdx.x"): | 
|  | for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"): | 
|  | for ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) // T.int64(100)) | 
|  | v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) % T.int64(100)) | 
|  | T.reads() | 
|  | T.writes(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) | 
|  | matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) | 
|  | for ax2_fused_0, u in T.grid((n + T.int64(15)) // T.int64(16), 1): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) // T.int64(100)) | 
|  | v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) % T.int64(100)) | 
|  | vax2_fused_0 = T.axis.reduce((n + T.int64(15)) // T.int64(16), ax2_fused_0) | 
|  | T.where(ax2_fused_0 * T.int64(16) + ax2_fused_1 < n) | 
|  | T.reads(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1], B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1]) | 
|  | T.writes(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) | 
|  | matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] * B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1] | 
|  | for ax1_ax2_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"): | 
|  | for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vax2_fused_1 = T.axis.reduce(T.int64(16), ax0) | 
|  | v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused_0 // T.int64(10)) | 
|  | v1 = T.axis.spatial(T.int64(100), ax0_ax1_fused_0 % T.int64(10) * T.int64(10) + ax1_ax2_fused) | 
|  | T.reads(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) | 
|  | T.writes(matmul[T.int64(0), v0, T.int64(0), v1]) | 
|  | with T.init(): | 
|  | matmul[T.int64(0), v0, T.int64(0), v1] = T.float16(0) | 
|  | matmul[T.int64(0), v0, T.int64(0), v1] = matmul[T.int64(0), v0, T.int64(0), v1] + matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] | 
|  | # fmt: on | 
|  |  | 
|  | with Target("nvidia/geforce-rtx-3090-ti"): | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_repeat_transpose_gemv(): | 
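# repeat + permute_dims + matmul fused into one GEMV; the scheduled version
# reads lv716 directly (indexing with v0 // 4) instead of materializing the
# repeat/transpose intermediates.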
|  | # fmt: off | 
|  |  | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func(private=True) | 
|  | def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_astype66: T.handle, var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | kv_seq_len = T.int64() | 
|  | lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16") | 
|  | astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16") | 
|  | # with T.block("root"): | 
|  | var_T_repeat_intermediate = T.alloc_buffer((T.int64(1), kv_seq_len, T.int64(32), T.int64(128)), "float16") | 
|  | var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), kv_seq_len, T.int64(128)), "float16") | 
|  | for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), kv_seq_len, T.int64(32), T.int64(128)): | 
|  | with T.block("T_repeat"): | 
|  | v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) | 
|  | T.reads(lv716[v_ax0, v_ax1, v_ax2 // T.int64(4), v_ax3]) | 
|  | T.writes(var_T_repeat_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) | 
|  | var_T_repeat_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv716[v_ax0, v_ax1, v_ax2 // T.int64(4), v_ax3] | 
|  | for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), kv_seq_len, T.int64(128)): | 
|  | with T.block("T_transpose"): | 
|  | v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) | 
|  | T.reads(var_T_repeat_intermediate[v_ax0, v_ax2, v_ax1, v_ax3]) | 
|  | T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) | 
|  | var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_repeat_intermediate[v_ax0, v_ax2, v_ax1, v_ax3] | 
|  | for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) | 
|  | T.reads(astype66[v_i0, v_i1, v_i2, v_k], var_T_transpose_intermediate[v_i0, v_i1, v_k, v_i3]) | 
|  | T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) | 
|  | with T.init(): | 
|  | var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) | 
|  | var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + astype66[v_i0, v_i1, v_i2, v_k] * var_T_transpose_intermediate[v_i0, v_i1, v_k, v_i3] | 
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func(private=True) | 
|  | def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_astype66: T.handle, var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | kv_seq_len = T.int64() | 
|  | lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16") | 
|  | astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16") | 
|  | # with T.block("root"): | 
|  | var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16", scope="local") | 
|  | for ax0_0_ax1_fused_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"): | 
|  | for ax0_0_ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): | 
|  | for ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): | 
|  | for ax0_1_init in range(T.int64(4)): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(32), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) // T.int64(128) * T.int64(4) + ax0_1_init) | 
|  | v1 = T.axis.spatial(T.int64(128), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) % T.int64(128)) | 
|  | T.reads() | 
|  | T.writes(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) | 
|  | var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0) | 
|  | for ax2_fused_0, ax0_1 in T.grid((kv_seq_len + T.int64(15)) // T.int64(16), T.int64(4)): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(32), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) // T.int64(128) * T.int64(4) + ax0_1) | 
|  | v1 = T.axis.spatial(T.int64(128), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) % T.int64(128)) | 
|  | vax2_fused_0 = T.axis.reduce((kv_seq_len + T.int64(15)) // T.int64(16), ax2_fused_0) | 
|  | T.where(ax2_fused_0 * T.int64(16) + ax2_fused_1 < kv_seq_len) | 
|  | T.reads(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], astype66[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1], lv716[T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1, v0 // T.int64(4), v1]) | 
|  | T.writes(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) | 
|  | var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + astype66[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] * lv716[T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1, v0 // T.int64(4), v1] | 
|  | for ax1_0_ax2_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"): | 
|  | for ax1_1 in range(T.int64(4)): | 
|  | for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vax2_fused_1 = T.axis.reduce(T.int64(16), ax0) | 
|  | v0 = T.axis.spatial(T.int64(32), ax0_0_ax1_fused_0 // T.int64(8) * T.int64(4) + ax1_1) | 
|  | v1 = T.axis.spatial(T.int64(128), ax0_0_ax1_fused_0 % T.int64(8) * T.int64(16) + ax1_0_ax2_fused) | 
|  | T.reads(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]) | 
|  | T.writes(var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1]) | 
|  | with T.init(): | 
|  | var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1] = T.float16(0) | 
|  | var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1] = var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1] + var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] | 
|  | # fmt: on | 
|  |  | 
|  | with Target("nvidia/geforce-rtx-3090-ti"): | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_gemv_dyn_shape_epilogue(): | 
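|  | # GEMV with a dynamic spatial extent (vocab_size): the 4096-wide reduction axis is static, | 
|  | # and a cast-to-float32 epilogue follows the float16 matmul. | 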
|  | @I.ir_module | 
|  | class Module: | 
|  | @T.prim_func(private=True) | 
|  | def main( | 
|  | var_A: T.handle, | 
|  | B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), | 
|  | var_C: T.handle, | 
|  | ): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | vocab_size = T.int64() | 
|  | A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16") | 
|  | C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size)) | 
|  | C_temp = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16") | 
|  | for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), vocab_size, T.int64(4096)): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(B[v_i0, v_i1, v_k], A[v_k, v_i2]) | 
|  | T.writes(C_temp[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | C_temp[v_i0, v_i1, v_i2] = T.float16(0) | 
|  | C_temp[v_i0, v_i1, v_i2] = ( | 
|  | C_temp[v_i0, v_i1, v_i2] + B[v_i0, v_i1, v_k] * A[v_k, v_i2] | 
|  | ) | 
|  | for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), vocab_size): | 
|  | with T.block("epilogue"): | 
|  | v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) | 
|  | T.reads(C_temp[v_i0, v_i1, v_i2]) | 
|  | T.writes(C[v_i0, v_i1, v_i2]) | 
|  | C[v_i0, v_i1, v_i2] = T.Cast("float32", C_temp[v_i0, v_i1, v_i2]) | 
|  |  | 
|  | # fmt: off | 
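|  | # Expected: one blockIdx.x per output element; the reduction is rfactored into 256 serial steps | 
|  | # times 16 threadIdx.y lanes, then cross-thread reduced, and the cast epilogue stays in the same kernel. | 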
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func(private=True) | 
|  | def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_C: T.handle): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | vocab_size = T.int64() | 
|  | A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16") | 
|  | C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size)) | 
|  | # with T.block("root"): | 
|  | C_temp_local = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16", scope="local") | 
|  | C_temp_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(1), vocab_size), "float16", scope="local") | 
|  | for ax0_fused_0 in T.thread_binding(vocab_size, thread="blockIdx.x"): | 
|  | for ax0_fused_1 in T.thread_binding(T.int64(1), thread="threadIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"): | 
|  | with T.block("matmul_rf_init"): | 
|  | vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1) | 
|  | v0 = T.axis.spatial(vocab_size, ax0_fused_0 + ax0_fused_1) | 
|  | T.reads() | 
|  | T.writes(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0]) | 
|  | C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] = T.float16(0) | 
|  | for ax1_fused_0, u in T.grid(T.int64(256), 1): | 
|  | with T.block("matmul_rf_update"): | 
|  | vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1) | 
|  | v0 = T.axis.spatial(vocab_size, ax0_fused_0 + ax0_fused_1) | 
|  | vax1_fused_0 = T.axis.reduce(T.int64(256), ax1_fused_0) | 
|  | T.reads(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0], B[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], A[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]) | 
|  | T.writes(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0]) | 
|  | C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] = C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] + B[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * A[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0] | 
|  | for ax1_fused in T.thread_binding(T.int64(1), thread="threadIdx.x"): | 
|  | for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): | 
|  | with T.block("matmul"): | 
|  | vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused_0]) | 
|  | T.reads(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0]) | 
|  | T.writes(C_temp_local[T.int64(0), T.int64(0), v0]) | 
|  | with T.init(): | 
|  | C_temp_local[T.int64(0), T.int64(0), v0] = T.float16(0) | 
|  | C_temp_local[T.int64(0), T.int64(0), v0] = C_temp_local[T.int64(0), T.int64(0), v0] + C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] | 
|  | for ax0_fused_0_1 in T.thread_binding(T.int64(1), thread="threadIdx.x"): | 
|  | for ax0_fused_1 in range(T.int64(1)): | 
|  | with T.block("epilogue"): | 
|  | v0 = T.axis.spatial(vocab_size, ax0_fused_0) | 
|  | T.reads(C_temp_local[T.int64(0), T.int64(0), v0]) | 
|  | T.writes(C[T.int64(0), T.int64(0), v0]) | 
|  | C[T.int64(0), T.int64(0), v0] = T.Cast("float32", C_temp_local[T.int64(0), T.int64(0), v0]) | 
|  | # fmt: on | 
|  |  | 
|  | with Target("nvidia/geforce-rtx-3090-ti"): | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_gemv_output_one_element(): | 
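|  | # Dot product producing a single output element: a 2048-wide reduction handled by one thread | 
|  | # block, staged in shared memory, followed by a sigmoid epilogue. | 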
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func(private=True) | 
|  | def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), T.int64(1)), "float16")): | 
|  | T.func_attr({"tir.noalias": True}) | 
|  | NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1)), "float16") | 
|  | for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)): | 
|  | with T.block("NT_matmul"): | 
|  | v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) | 
|  | with T.init(): | 
|  | NT_matmul_intermediate[v_i0, v_i1] = T.float16(0) | 
|  | NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + A[v_i0, v_k] * weight[v_i1, v_k] | 
|  | for i0, i1 in T.grid(T.int64(1), T.int64(1)): | 
|  | with T.block("compute"): | 
|  | v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) | 
|  | out[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0, v_i1]) | 
|  |  | 
|  |  | 
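|  | # Expected: 1024 threadIdx.x lanes each accumulate two elements into a local rfactor buffer, | 
|  | # a cross-thread reduction stages the scalar in shared memory, and a single thread (guarded | 
|  | # by T.where) applies the sigmoid and writes the output. | 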
|  | @I.ir_module | 
|  | class Expected: | 
|  | @T.prim_func(private=True) | 
|  | def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), T.int64(1)), "float16")): | 
|  | T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) | 
|  | NT_matmul_intermediate_shared = T.alloc_buffer((T.int64(1), T.int64(1)), "float16", scope="shared") | 
|  | NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(1024), T.int64(1), T.int64(1)), "float16", scope="local") | 
|  | for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): | 
|  | for ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): | 
|  | with T.block("NT_matmul_rf_init"): | 
|  | vax1_fused_1 = T.axis.spatial(T.int64(1024), ax1_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = T.float16(0) | 
|  | for ax1_fused_0, u in T.grid(T.int64(2), 1): | 
|  | with T.block("NT_matmul_rf_update"): | 
|  | vax1_fused_1 = T.axis.spatial(T.int64(1024), ax1_fused_1) | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | vax1_fused_0 = T.axis.reduce(T.int64(2), ax1_fused_0) | 
|  | NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] + A[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1] * weight[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1] | 
|  | for ax1_fused in range(T.int64(1)): | 
|  | for ax0 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): | 
|  | with T.block("NT_matmul"): | 
|  | vax1_fused_1 = T.axis.reduce(T.int64(1024), ax0) | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | with T.init(): | 
|  | NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] = T.float16(0) | 
|  | NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] = NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] + NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] | 
|  | for ax0_fused_0 in range(T.int64(1)): | 
|  | for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): | 
|  | with T.block("compute"): | 
|  | v0 = T.axis.spatial(T.int64(1), T.int64(0)) | 
|  | T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) | 
|  | out[T.int64(0), T.int64(0)] = T.sigmoid(NT_matmul_intermediate_shared[T.int64(0), T.int64(0)]) | 
|  | # fmt: on | 
|  |  | 
|  | with Target("nvidia/geforce-rtx-3090-ti"): | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Expected) | 
|  |  | 
|  |  | 
|  | def test_no_reduction_loop_check(): | 
|  | # After normalization the prim func contains no reduction loop, since the reduction extent is one. | 
|  | # This checks that the Reduction schedule is correctly not applied in this case. | 
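|  | # The rule therefore does not match, and ApplyDefaultSchedule must return the module unchanged, | 
|  | # which the structural-equality assertion at the end of the test verifies. | 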
|  | # fmt: off | 
|  | @I.ir_module | 
|  | class Before: | 
|  | @T.prim_func(private=True) | 
|  | def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): | 
|  | T.func_attr({"op_pattern": 4, "tir.noalias": True}) | 
|  | # with T.block("root"): | 
|  | for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)): | 
|  | with T.block("matmul"): | 
|  | v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) | 
|  | T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2]) | 
|  | T.writes(matmul[v_i0, v_i1, v_i2]) | 
|  | with T.init(): | 
|  | matmul[v_i0, v_i1, v_i2] = T.float16(0.0) | 
|  | matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2] | 
|  | # fmt: on | 
|  |  | 
|  | target = Target("nvidia/geforce-rtx-3090-ti") | 
|  | with target: | 
|  | mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable | 
|  | assert_structural_equal(mod, Before) | 
|  |  | 
|  |  | 
|  | if __name__ == "__main__": | 
|  | tvm.testing.main() |