| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=missing-docstring, unused-variable, invalid-name |
| # flake8: noqa: E501 |
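
# Tests that the dlight GPU Matmul rule tensorizes matmuls to Tensor Core
# (WMMA) intrinsics on CUDA targets and Metal simdgroup intrinsics on Metal,
# and falls back to a plain shared-memory schedule when shapes are too small.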
| import pytest |
| |
| import tvm.testing |
| from tvm import dlight as dl |
| from tvm.script import tir as T |
| from tvm.target import Target |
| |
| |
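# Each subclass of tvm.testing.CompareBeforeAfter defines a `before` PrimFunc,
# an `expected` PrimFunc, and a `transform` fixture; the harness applies the
# transform to `before` and asserts structural equality with `expected`.
# Outside the harness, the same transform could be applied standalone,
# roughly like this minimal sketch (assuming `before` is one of the
# prim_funcs below):
#
#     mod = tvm.IRModule({"main": before})
#     with Target("nvidia/geforce-rtx-2080-ti"):
#         mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
#     print(mod.script())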
| class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): |
| @pytest.fixture |
| def transform(self): |
| def transform(mod): |
| with Target("nvidia/geforce-rtx-2080-ti"): |
| return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) |
| |
| return transform |
| |
| |
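# Static 256x256x256 fp16 matmul: the rule should emit a WMMA (Tensor Core)
# schedule with shared.dyn staging, software pipelining, and 16x16x16
# fragment intrinsics.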
| class TestMatmulTensorize(BaseBeforeAfter): |
| # fmt: off |
| |
| @T.prim_func |
| def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| # with T.block("root"): |
| for i, j, k in T.grid(256, 256, 256): |
| with T.block("compute"): |
| v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) |
| T.reads(X[v_i, v_k], W[v_j, v_k]) |
| T.writes(compute[v_i, v_j]) |
| with T.init(): |
| compute[v_i, v_j] = T.float16(0) |
| compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * W[v_j, v_k] |
| |
| @T.prim_func |
| def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): |
| T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) |
| # with T.block("root"): |
| X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") |
| W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") |
| X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_a") |
| W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_b") |
| compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") |
| compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.accumulator") |
| for ax0 in T.thread_binding(1, thread="blockIdx.z"): |
| for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"): |
| for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): |
| for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): |
| for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): |
| with T.block("compute_o_init"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) |
| v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) |
| T.reads() |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("compute_init_o"): |
| v1_i_init_o = T.axis.spatial(1, 0) |
| v2_i_init_o = T.axis.spatial(1, 0) |
| T.reads() |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) |
                        for ax3_0_0 in T.serial(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
| for ax0_ax1_fused_0 in range(4): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("X_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) |
| v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) |
| T.reads(X[v1, v2]) |
| T.writes(X_reindex_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] |
| for ax0_ax1_fused_0 in range(4): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("W_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) |
| v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) |
| T.reads(W[v1, v2]) |
| T.writes(W_reindex_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] |
                            for ax3_0_1 in T.serial(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}):
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(16, ax3_0_0 * 4 + ax3_0_1 + ax1_0) |
| T.reads(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") |
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(16, ax3_0_0 * 4 + ax3_0_1 + ax1_0) |
| T.reads(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") |
| for ax1_0_3, ax2_0_3 in T.grid(2, 2): |
| with T.block("compute_o_update"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) |
| v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) |
| v3_o = T.axis.reduce(16, ax3_0_0 * 4 + ax3_0_1) |
| T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("compute_o"): |
| v1_i_o = T.axis.spatial(1, 0) |
| v2_i_o = T.axis.spatial(1, 0) |
| v3_i_o = T.axis.reduce(1, 0) |
| T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) |
| B = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) |
| C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) |
| for ax0_0, ax1_0 in T.grid(2, 2): |
| with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) |
| T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) |
| C = T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) |
| T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") |
| for ax0_ax1_fused_0 in range(8): |
| for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_2 in T.vectorized(4): |
| with T.block("compute_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) |
| v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) |
| T.reads(compute_reindex_shared_dyn[v0, v1, v2]) |
| T.writes(compute[v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) |
| compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] |
| |
| # fmt: on |
| |
| |
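# The N dimension (15) is smaller than a 16x16x16 WMMA tile, so tensorization
# must be skipped: the fallback schedule accumulates in fp32 `local` scratch
# with plain shared-memory tiling and a predicated write-back.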
| class TestMatmulTensorizeTooSmall(BaseBeforeAfter): |
| # fmt: off |
| |
| @T.prim_func |
| def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| m = T.int32() |
| X = T.match_buffer(var_X, (m, 256), "float16") |
| compute = T.match_buffer(var_compute, (m, 15)) |
| # with T.block("root"): |
| for i, j, k in T.grid(m, 15, 256): |
| with T.block("compute"): |
| v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) |
| T.reads(X[v_i, v_k], W[v_j, v_k]) |
| T.writes(compute[v_i, v_j]) |
| with T.init(): |
| compute[v_i, v_j] = T.float32(0) |
| compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("float32", X[v_i, v_k]) * T.Cast("float32", W[v_j, v_k]) |
| |
| @T.prim_func |
| def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): |
| T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) |
| m = T.int32() |
| X = T.match_buffer(var_X, (m, 256), "float16") |
| compute = T.match_buffer(var_compute, (m, 15)) |
| # with T.block("root"): |
| compute_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 64), scope="local") |
| X_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 256), "float16", scope="shared") |
| W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", scope="shared") |
| for ax0_ax2_0_fused in T.thread_binding(1, thread="blockIdx.y"): |
| for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"): |
| for ax2_1 in T.thread_binding(1, thread="vthread.y"): |
| for ax1_1 in T.thread_binding(1, thread="vthread.x"): |
| for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): |
| for ax1_3_init, ax2_3_0_init in T.grid(4, 2): |
| for ax2_3_1_init in T.vectorized(2): |
| with T.block("compute_init"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) |
| v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) |
| T.reads() |
| T.writes(compute_reindex_pad_local[0, v1, v2]) |
| compute_reindex_pad_local[0, v1, v2] = T.float32(0) |
| for ax3_0 in range(16): |
| for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): |
| for ax0_ax1_ax2_fused_2 in range(2): |
| for ax0_ax1_ax2_fused_3 in T.vectorized(2): |
| with T.block("X_reindex_pad_shared"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) |
| v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) |
| T.reads(X[v1, v2]) |
| T.writes(X_reindex_pad_shared[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) |
| X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2], T.float16(0)) |
| for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): |
| for ax0_ax1_ax2_fused_2 in range(4): |
| for ax0_ax1_ax2_fused_3 in T.vectorized(2): |
| with T.block("W_reindex_pad_shared"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) |
| v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) |
| T.reads(W[v1, v2]) |
| T.writes(W_reindex_pad_shared[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) |
| W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], T.float16(0)) |
| for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): |
| for ax2_3_1 in T.vectorized(2): |
| with T.block("compute_update"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) |
| v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) |
| v3 = T.axis.reduce(256, ax3_0 * 16 + ax3_1) |
| T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], W_reindex_pad_shared[0, v2, v3]) |
| T.writes(compute_reindex_pad_local[0, v1, v2]) |
| compute_reindex_pad_local[0, v1, v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32", X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0, v2, v3]) |
| for ax0, ax1, ax2_0 in T.grid(1, 4, 2): |
| for ax2_1_1 in T.vectorized(2): |
| with T.block("compute_reindex_pad_local"): |
| v0 = T.axis.spatial(1, ax0) |
| v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) |
| v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) |
| T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15) |
| T.reads(compute_reindex_pad_local[v0, v1, v2]) |
| T.writes(compute[v1, v2]) |
| compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2] |
| # fmt: on |
| |
| |
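# Fused decode (4-bit weights unpacked from uint32) + NT_matmul + divide + add
# with a dynamic sequence length `n`: the dequantize step should be inlined
# into the B-matrix shared-memory load, and the divide/add epilogue fused into
# the final shared-to-global write-back.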
| class TestMatmulTensorizeEpilogue(BaseBeforeAfter): |
| # fmt: off |
| |
| @T.prim_func |
| def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Buffer((T.int32(4096), T.int32(64)), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): |
| T.func_attr({"tir.noalias": True}) |
| n = T.int32() |
| lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16") |
| lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16") |
| p_output0_intermediate = T.match_buffer(p_output0, (T.int32(1), n, T.int32(4096)), "float16") |
| # with T.block("root"): |
| p_output0_intermediate_1 = T.alloc_buffer((T.int32(4096), T.int32(2048)), "float16") |
| var_NT_matmul_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") |
| var_T_divide_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") |
| for i, j in T.grid(T.int32(4096), T.int32(2048)): |
| with T.block("decode"): |
| v_i, v_j = T.axis.remap("SS", [i, j]) |
| T.reads(lv686[v_i, v_j // T.int32(8)], lv687[v_i, v_j // T.int32(32)]) |
| T.writes(p_output0_intermediate_1[v_i, v_j]) |
| p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v_i, v_j // T.int32(8)], T.Cast("uint32", v_j % T.int32(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v_i, v_j // T.int32(32)] |
| for i0, i1, i2, k in T.grid(T.int32(1), n, T.int32(4096), T.int32(2048)): |
| with T.block("NT_matmul"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(lv42[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) |
| T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) |
| with T.init(): |
| var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) |
| var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv42[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] |
| for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): |
| with T.block("T_divide"): |
| v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) |
| T.reads(lv3[v_ax0, v_ax1, v_ax2]) |
| T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2]) |
| var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * T.float16(0.5) |
| for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): |
| with T.block("T_add"): |
| v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) |
| T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) |
| T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) |
| p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] |
| |
| @T.prim_func |
| def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): |
| T.func_attr({"global_symbol": "fused_fused_decode3_fused_NT_matmul6_divide1_add1", "tir.is_scheduled": True, "tir.noalias": True}) |
| n = T.int32() |
| lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16") |
| lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16") |
| p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") |
| # with T.block("root"): |
| lv42_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="shared.dyn") |
| p_output0_intermediate_1_reindex_shared_dyn = T.alloc_buffer((1, 4096, 2048), "float16", scope="shared.dyn") |
| lv42_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="wmma.matrix_a") |
| p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 2048), "float16", scope="wmma.matrix_b") |
| var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="shared.dyn") |
| var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="wmma.accumulator") |
| for ax0 in T.thread_binding(1, thread="blockIdx.z"): |
| for ax1_0_0_ax2_0_0_fused in T.thread_binding((n + 127) // 128, thread="blockIdx.x"): |
| for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): |
| for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): |
| for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): |
| with T.block("NT_matmul_o_init"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) |
| v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) |
| T.reads() |
| T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("NT_matmul_init_o"): |
| v1_i_init_o = T.axis.spatial(1, 0) |
| v2_i_init_o = T.axis.spatial(1, 0) |
| T.reads() |
| T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) |
                        for ax3_0_0 in T.serial(32, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
| for ax0_ax1_fused_0 in range(4): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("lv42_reindex_pad_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) |
| v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) |
| T.reads(lv42[v0, v1, v2]) |
| T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0)) |
| for ax0_ax1_fused_0 in range(4): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("p_output0_intermediate_1_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) |
| v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) |
| T.reads(lv686[v1, v2 // 8], lv687[v1, v2 // 32]) |
| T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32] |
                            for ax3_0_1 in T.serial(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}):
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) |
| T.reads(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") |
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) |
| T.reads(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") |
| for ax1_0_3, ax2_0_3 in T.grid(2, 2): |
| with T.block("NT_matmul_o_update"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) |
| v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) |
| v3_o = T.axis.reduce(128, ax3_0_0 * 4 + ax3_0_1) |
| T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("NT_matmul_o"): |
| v1_i_o = T.axis.spatial(1, 0) |
| v2_i_o = T.axis.spatial(1, 0) |
| v3_i_o = T.axis.reduce(1, 0) |
| T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) |
| B = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) |
| C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) |
| for ax0_0, ax1_0 in T.grid(2, 2): |
| with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) |
| T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) |
| C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) |
| T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") |
| for ax0_ax1_fused_0 in range(8): |
| for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_2 in T.vectorized(4): |
| with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) |
| v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) |
| T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < n) |
| T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]) |
| T.writes(p_output0_intermediate[0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) |
| p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2] |
| # fmt: on |
| |
| |
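# int8 x int8 -> int32 static matmul: same WMMA pipeline as the fp16 case, but
# with s8 fragments, an int32 accumulator, and 32-element shared-memory
# alignment (buffer_dim_align [[0, 1, 32, 16]]).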
| class TestMatmulInt8Tensorize(BaseBeforeAfter): |
| # fmt: off |
| @T.prim_func |
| def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| # with T.block("root"): |
        for i, j, k in T.grid(256, 256, 256):
            with T.block("compute"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
| T.reads(X[v_i, v_k], W[v_j, v_k]) |
| T.writes(compute[v_i, v_j]) |
| with T.init(): |
| compute[v_i, v_j] = 0 |
| compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("int32", X[v_i, v_k]) * T.Cast("int32", W[v_j, v_k]) |
| |
| @T.prim_func |
| def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): |
| T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) |
| # with T.block("root"): |
| X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") |
| W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") |
| X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_a") |
| W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_b") |
| compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int32", scope="shared.dyn") |
| compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "int32", scope="wmma.accumulator") |
| for ax0 in T.thread_binding(1, thread="blockIdx.z"): |
| for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"): |
| for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): |
| for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): |
| for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): |
| with T.block("compute_o_init"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) |
| v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) |
| T.reads() |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("compute_init_o"): |
| v1_i_init_o = T.axis.spatial(1, 0) |
| v2_i_init_o = T.axis.spatial(1, 0) |
| T.reads() |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) |
| for ax3_0_0 in T.serial(16, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): |
| for ax0_ax1_fused_0 in range(1): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("X_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) |
| v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) |
| T.reads(X[v1, v2]) |
| T.writes(X_reindex_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] |
| for ax0_ax1_fused_0 in range(1): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("W_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) |
| v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) |
| T.reads(W[v1, v2]) |
| T.writes(W_reindex_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] |
| for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): |
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) |
| T.reads(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") |
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) |
| T.reads(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") |
| for ax1_0_3, ax2_0_3 in T.grid(2, 2): |
| with T.block("compute_o_update"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) |
| v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) |
| v3_o = T.axis.reduce(16, ax3_0_0 + ax3_0_1) |
| T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("compute_o"): |
| v1_i_o = T.axis.spatial(1, 0) |
| v2_i_o = T.axis.spatial(1, 0) |
| v3_i_o = T.axis.reduce(1, 0) |
| T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) |
| B = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) |
| C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) |
| for ax0_0, ax1_0 in T.grid(2, 2): |
| with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) |
| T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) |
| C = T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) |
| T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") |
| for ax0_ax1_fused_0 in range(8): |
| for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_2 in T.vectorized(4): |
| with T.block("compute_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) |
| v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) |
| T.reads(compute_reindex_shared_dyn[v0, v1, v2]) |
| T.writes(compute[v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) |
| compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] |
| # fmt: on |
| |
| |
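# int8 3D x 2D matmul with a dynamic row dimension `m`: rows are padded to a
# multiple of 128 so the WMMA tiling stays static, padded loads are filled
# with T.if_then_else, and the write-back is guarded by T.where(... < m).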
| class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter): |
| # fmt: off |
| @T.prim_func |
| def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): |
| T.func_attr({"op_pattern": 4, "tir.noalias": True}) |
| m = T.int32() |
| A = T.match_buffer(var_A, (1, m, 22016), "int8") |
| matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") |
| # with T.block("root"): |
| for i0, i1, i2, k in T.grid(1, m, 4096, 22016): |
| with T.block("matmul"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) |
| T.writes(matmul_1[v_i0, v_i1, v_i2]) |
| with T.init(): |
| matmul_1[v_i0, v_i1, v_i2] = 0 |
| matmul_1[v_i0, v_i1, v_i2] = matmul_1[v_i0, v_i1, v_i2] + T.Cast("int32", A[v_i0, v_i1, v_k]) * T.Cast("int32", B[v_i2, v_k]) |
| |
| @T.prim_func |
| def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): |
| T.func_attr({"op_pattern": 4, "tir.is_scheduled": True, "tir.noalias": True}) |
| m = T.int32() |
| A = T.match_buffer(var_A, (1, m, 22016), "int8") |
| matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") |
| # with T.block("root"): |
| A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="shared.dyn") |
| B_reindex_shared_dyn = T.alloc_buffer((1, 4096, 22016), "int8", scope="shared.dyn") |
| A_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="wmma.matrix_a") |
| B_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 22016), "int8", scope="wmma.matrix_b") |
| matmul_1_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 4096), "int32", scope="shared.dyn") |
| matmul_1_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (m + 127) // 128 * 128, 4096), "int32", scope="wmma.accumulator") |
| for ax0 in T.thread_binding(1, thread="blockIdx.z"): |
| for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128, thread="blockIdx.x"): |
| for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): |
| for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): |
| for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): |
| with T.block("matmul_o_init"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) |
| v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) |
| T.reads() |
| T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("matmul_init_o"): |
| v1_i_init_o = T.axis.spatial(1, 0) |
| v2_i_init_o = T.axis.spatial(1, 0) |
| T.reads() |
| T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) |
| for ax3_0_0 in T.serial(1376, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): |
| for ax0_ax1_fused_0 in range(1): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("A_reindex_pad_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) |
| v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) |
| T.reads(A[v0, v1, v2]) |
| T.writes(A_reindex_pad_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0)) |
| for ax0_ax1_fused_0 in range(1): |
| for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): |
| for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_3 in T.vectorized(4): |
| with T.block("B_reindex_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) |
| v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) |
| T.reads(B[v1, v2]) |
| T.writes(B_reindex_shared_dyn[v0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) |
| B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] |
| for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): |
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("A_reindex_pad_shared.dyn_wmma.matrix_a_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) |
| T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A_1 = T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "row_major") |
| for ax0_0 in T.unroll(2): |
| for ax1_0 in T.unroll(1): |
| with T.block("B_reindex_shared.dyn_wmma.matrix_b_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) |
| T.reads(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A_1 = T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) |
| C = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) |
| T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "col_major") |
| for ax1_0_3, ax2_0_3 in T.grid(2, 2): |
| with T.block("matmul_o_update"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) |
| v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) |
| v3_o = T.axis.reduce(1376, ax3_0_0 + ax3_0_1) |
| T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| with T.block("matmul_o"): |
| v1_i_o = T.axis.spatial(1, 0) |
| v2_i_o = T.axis.spatial(1, 0) |
| v3_i_o = T.axis.reduce(1, 0) |
| T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) |
| T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A_1 = T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) |
| B_1 = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) |
| C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) |
| T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A_1.data, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, B_1.data, B_1.elem_offset // B_1.strides[0] // 16 * (B_1.strides[0] // 16) + B_1.elem_offset % B_1.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) |
| for ax0_0, ax1_0 in T.grid(2, 2): |
| with T.block("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) |
| v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) |
| T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| T.writes(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) |
| A_1 = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) |
| C = T.match_buffer(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) |
| T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") |
| for ax0_ax1_fused_0 in range(8): |
| for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax0_ax1_fused_2 in T.vectorized(4): |
| with T.block("matmul_1_reindex_pad_shared.dyn"): |
| v0 = T.axis.spatial(1, 0) |
| v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) |
| v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) |
| T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < m) |
| T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2]) |
| T.writes(matmul_1[0, v1, v2]) |
| T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) |
| matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2] |
| # fmt: on |
| |
| |
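| # The Metal fixture below mirrors the CUDA fixture above, but under a "metal" |
| # target the Matmul rule tensorizes with metal.simdgroup intrinsics rather |
| # than WMMA. |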
| class MetalBeforeAfter(tvm.testing.CompareBeforeAfter): |
| @pytest.fixture |
| def transform(self): |
| def transform(mod): |
| with Target("metal"): |
| return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) |
| |
| return transform |
| |
| |
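| # Checks a dynamic-batch float16 NT-matmul (A: [batch_size, 1, 4096] against |
| # B: [28672, 4096]): the batch axis is padded to a multiple of 16, both |
| # operands are staged through shared memory, and the inner 8x8 tiles go |
| # through T.simdgroup_load / T.simdgroup_multiply_accumulate / |
| # T.simdgroup_store. |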
| class TestMatmulMetal(MetalBeforeAfter): |
| # fmt: off |
| @T.prim_func(private=True) |
| def before( |
| var_A: T.handle, |
| B: T.Buffer((28672, 4096), "float16"), |
| var_C: T.handle, |
| ): |
| batch_size = T.int32() |
| A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") |
| C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") |
| for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096): |
| with T.block("C"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) |
| T.writes(C[v_i0, v_i1, v_i2]) |
| with T.init(): |
| C[v_i0, v_i1, v_i2] = T.float16(0) |
| C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] |
| |
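| # A minimal sketch of reproducing the transform outside the harness (hedged: |
| # assumes wrapping the PrimFunc under the module name "main" is acceptable): |
| # |
| #     mod = tvm.IRModule({"main": TestMatmulMetal.before}) |
| #     with Target("metal"): |
| #         mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) |
| #     mod.show()  # structurally equal to `expected` below |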
| @T.prim_func(private=True) |
| def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.handle): |
| T.func_attr({"tir.is_scheduled": True}) |
| batch_size = T.int32() |
| A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") |
| C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") |
| # with T.block("root"): |
| A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared") |
| B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared") |
| A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup") |
| B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup") |
| C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup") |
| C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared") |
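| # Tile decomposition: each of the 4 threadIdx.z simdgroups owns a 2x2 grid of |
| # 8x8 fragments, so one thread block covers a 16x64 slice of the output. |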
| for ax0 in T.thread_binding(1, thread="blockIdx.z"): |
| for ax1_0 in T.thread_binding((batch_size + 15) // 16, thread="blockIdx.x"): |
| for ax2_0 in T.thread_binding(448, thread="blockIdx.y"): |
| for ax1_1 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax2_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1): |
| with T.block("C_init_o"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0) |
| v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0) |
| T.reads() |
| T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) |
| for ax3_0 in range(128): |
| for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1): |
| for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax1_ax2_fused_4 in T.vectorized(4): |
| with T.block("A_reindex_pad_shared"): |
| v0 = T.axis.spatial(1, ax0_1) |
| v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) |
| v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) |
| T.reads(A[v1, 0, v2]) |
| T.writes(A_reindex_pad_shared[v0, v1, v2]) |
| A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0)) |
| for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4): |
| for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax1_ax2_fused_4 in T.vectorized(4): |
| with T.block("B_reindex_shared"): |
| v0 = T.axis.spatial(1, ax0_1) |
| v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) |
| v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) |
| T.reads(B[v1, v2]) |
| T.writes(B_reindex_shared[v0, v1, v2]) |
| B_reindex_shared[v0, v1, v2] = B[v1, v2] |
| for ax3_1 in range(4): |
| for ax0_0, ax1_0_1 in T.grid(2, 1): |
| with T.block("A_reindex_pad_shared_metal.simdgroup_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0) |
| v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) |
| T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) |
| C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) |
| for ax0_0, ax1_0_1 in T.grid(2, 1): |
| with T.block("B_reindex_shared_metal.simdgroup_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0) |
| v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) |
| T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8]) |
| A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) |
| C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) |
| for ax1_2, ax2_2 in T.grid(2, 2): |
| with T.block("C_update_o"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2) |
| v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2) |
| v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1) |
| T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) |
| B_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1) |
| C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) |
| for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): |
| with T.block("C_reindex_pad_metal.simdgroup_o"): |
| v0_o = T.axis.spatial(1, ax0_1) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1) |
| v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1) |
| T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) |
| C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) |
| T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) |
| for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2): |
| for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax1_ax2_fused_4 in T.vectorized(4): |
| with T.block("C_reindex_pad_shared"): |
| v0 = T.axis.spatial(1, ax0_1) |
| v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) |
| v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) |
| T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size) |
| T.reads(C_reindex_pad_shared[v0, v1, v2]) |
| T.writes(C[v1, 0, v2]) |
| C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] |
| # fmt: on |
| |
| |
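| # Same Metal matmul rule, but with an int4-quantized weight: B0 packs eight |
| # 4-bit values into each uint32 and B1 stores one float16 scale per group of |
| # 32 elements, so a weight element is recovered as (q - 7) * scale. |
| # |
| # A rough NumPy reference for the dequantize step (hedged sketch only, not |
| # used by the test; assumes numpy as np, with q_packed/scales standing in for |
| # B0/B1): |
| # |
| #     nibbles = (q_packed[:, :, None] >> (4 * np.arange(8, dtype=np.uint32))) & np.uint32(15) |
| #     weight = (nibbles.reshape(28672, 4096).astype("float16") - np.float16(7)) * scales.repeat(32, axis=1) |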
| class TestMatmulMetalInt4Quant(MetalBeforeAfter): |
| # fmt: off |
| @T.prim_func(private=True) |
| def before( |
| B0: T.Buffer((28672, 512), "uint32"), |
| B1: T.Buffer((28672, 128), "float16"), |
| var_A: T.handle, |
| var_C: T.handle |
| ): |
| batch_size = T.int32() |
| A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") |
| C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") |
| compute = T.alloc_buffer((28672, 4096), "float16") |
| B = T.alloc_buffer((28672, 4096), "float16") |
| for i0, i1 in T.grid(28672, 4096): |
| with T.block("compute"): |
| v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) |
| compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15))) |
| for i0, i1 in T.grid(28672, 4096): |
| with T.block("dequantize"): |
| v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) |
| B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0, v_i1 // 32] |
| for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096): |
| with T.block("NT_matmul"): |
| v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) |
| with T.init(): |
| C[v_i0, v_i1, v_i2] = T.float16(0) |
| C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] |
| |
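| # In the expected schedule the standalone "compute" and "dequantize" blocks |
| # vanish: dequantization is fused into the "B_reindex_shared" cooperative |
| # load, so each thread unpacks and rescales B on its way into shared memory. |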
| @T.prim_func(private=True) |
| def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "float16"), var_A: T.handle, var_C: T.handle): |
| T.func_attr({"tir.is_scheduled": True}) |
| batch_size = T.int32() |
| A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") |
| C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") |
| # with T.block("root"): |
| A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared") |
| B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared") |
| A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup") |
| B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup") |
| C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup") |
| C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared") |
| for ax0 in T.thread_binding(1, thread="blockIdx.z"): |
| for ax1_0 in T.thread_binding((batch_size + 15) // 16, thread="blockIdx.x"): |
| for ax2_0 in T.thread_binding(448, thread="blockIdx.y"): |
| for ax1_1 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax2_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1): |
| with T.block("NT_matmul_init_o"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0) |
| v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0) |
| T.reads() |
| T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) |
| for ax3_0 in range(128): |
| for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1): |
| for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax1_ax2_fused_4 in T.vectorized(4): |
| with T.block("A_reindex_pad_shared"): |
| v0 = T.axis.spatial(1, ax0_1) |
| v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) |
| v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) |
| T.reads(A[v1, 0, v2]) |
| T.writes(A_reindex_pad_shared[v0, v1, v2]) |
| A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0)) |
| for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4): |
| for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax1_ax2_fused_4 in T.vectorized(4): |
| with T.block("B_reindex_shared"): |
| v0 = T.axis.spatial(1, ax0_1) |
| v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32) |
| v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32) |
| T.reads(B0[v1, v2 // 8], B1[v1, v2 // 32]) |
| T.writes(B_reindex_shared[v0, v1, v2]) |
| B_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1[v1, v2 // 32] |
| for ax3_1 in range(4): |
| for ax0_0, ax1_0_1 in T.grid(2, 1): |
| with T.block("A_reindex_pad_shared_metal.simdgroup_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0) |
| v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) |
| T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) |
| C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) |
| for ax0_0, ax1_0_1 in T.grid(2, 1): |
| with T.block("B_reindex_shared_metal.simdgroup_o"): |
| v0_o = T.axis.spatial(1, 0) |
| v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0) |
| v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1) |
| T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8]) |
| A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) |
| C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) |
| for ax1_2, ax2_2 in T.grid(2, 2): |
| with T.block("NT_matmul_update_o"): |
| v0_o = T.axis.spatial(1, ax0) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2) |
| v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2) |
| v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1) |
| T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) |
| B = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1) |
| C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) |
| T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) |
| for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): |
| with T.block("C_reindex_pad_metal.simdgroup_o"): |
| v0_o = T.axis.spatial(1, ax0_1) |
| v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1) |
| v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1) |
| T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) |
| A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) |
| C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) |
| T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) |
| for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2): |
| for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): |
| for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): |
| for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"): |
| for ax1_ax2_fused_4 in T.vectorized(4): |
| with T.block("C_reindex_pad_shared"): |
| v0 = T.axis.spatial(1, ax0_1) |
| v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64) |
| v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64) |
| T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size) |
| T.reads(C_reindex_pad_shared[v0, v1, v2]) |
| T.writes(C[v1, 0, v2]) |
| C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2] |
| # fmt: on |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |