# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import tvm
import tvm.testing
from tvm.s_tir import dlight as dl
from tvm.script import tir as T
from tvm.target import Target
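# Tests for the dlight GPU Matmul schedule rule: each case builds an IRModule
# from an unscheduled matmul PrimFunc, applies
# dl.ApplyDefaultSchedule(dl.gpu.Matmul()) under a concrete Target, and asserts
# structural equality against a hand-written expected function.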


def test_matmul():
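    # Dynamic-m float32 matmul (1, m, 4096) x (4096, 4096) scheduled for a CUDA
    # target: expects shared-memory staging for both operands, a local
    # accumulator, and m padded up to a multiple of 32.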
# fmt: off
@T.prim_func(private=True)
def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
m = T.int64()
inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
for i0, i1, i2, k in T.grid(T.int64(1), m, T.int64(4096), T.int64(4096)):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
with T.init():
matmul[v_i0, v_i1, v_i2] = T.float32(0)
matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
@T.prim_func(private=True)
def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
T.func_attr({"tir.is_scheduled": True})
m = T.int64()
inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
# with T.sblock("root"):
matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)):
for ax2_3_1_init in T.vectorized(T.int64(2)):
with T.sblock("matmul_init"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init)
T.reads()
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
for ax3_0 in range(T.int64(256)):
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("inp0_reindex_pad_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(inp0[v0, v1, v2])
T.writes(inp0_reindex_pad_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("inp1_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(inp1[v2, v1])
T.writes(inp1_reindex_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
for ax2_3_1 in T.vectorized(T.int64(2)):
with T.sblock("matmul_update"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1)
v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
for ax2_1_1 in T.vectorized(T.int64(2)):
with T.sblock("matmul_reindex_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[T.int64(0), v1, v2])
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
mod = tvm.IRModule({"main": before})
with Target("nvidia/geforce-gtx-1080-ti"):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_matmul_int32():
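    # Same workload as test_matmul, but with int32 shape variables and indices,
    # checking that the scheduled function keeps the narrower index dtype.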
# fmt: off
@T.prim_func(private=True)
def func(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle):
m = T.int32()
inp0 = T.match_buffer(var_inp0, (1, m, 4096))
matmul = T.match_buffer(var_matmul, (1, m, 4096))
for i0, i1, i2, k in T.grid(1, m, 4096, 4096):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
with T.init():
matmul[v_i0, v_i1, v_i2] = T.float32(0)
matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
@T.prim_func(private=True)
def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle):
T.func_attr({"tir.is_scheduled": True})
m = T.int32()
inp0 = T.match_buffer(var_inp0, (1, m, 4096))
matmul = T.match_buffer(var_matmul, (1, m, 4096))
# with T.sblock("root"):
matmul_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="local")
inp0_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 4096), scope="shared")
inp1_reindex_shared = T.alloc_buffer((1, 4096, 4096), scope="shared")
for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax1_3_init, ax2_3_0_init in T.grid(4, 2):
for ax2_3_1_init in T.vectorized(2):
with T.sblock("matmul_init"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init)
T.reads()
T.writes(matmul_reindex_pad_local[0, v1, v2])
matmul_reindex_pad_local[0, v1, v2] = T.float32(0)
for ax3_0 in range(256):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.sblock("inp0_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(inp0[v0, v1, v2])
T.writes(inp0_reindex_pad_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.sblock("inp1_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(inp1[v2, v1])
T.writes(inp1_reindex_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2):
for ax2_3_1 in T.vectorized(2):
with T.sblock("matmul_update"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1)
v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
T.reads(matmul_reindex_pad_local[0, v1, v2], inp0_reindex_pad_shared[0, v1, v3], inp1_reindex_shared[0, v2, v3])
T.writes(matmul_reindex_pad_local[0, v1, v2])
matmul_reindex_pad_local[0, v1, v2] = matmul_reindex_pad_local[0, v1, v2] + inp0_reindex_pad_shared[0, v1, v3] * inp1_reindex_shared[0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.sblock("matmul_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[0, v1, v2])
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
mod = tvm.IRModule({"main": func})
with Target("nvidia/geforce-gtx-1080-ti"):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_fused_matmul():
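    # Static (1, 32, 4096) matmul with a 4-bit weight-decode producer and an
    # elementwise add consumer: the decode is expected to be inlined into the
    # shared-memory load of the weight operand and the add into the write-back.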
# fmt: off
@T.prim_func(private=True)
def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)))
var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)))
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.sblock("decode"):
v_i, v_j = T.axis.remap("SS", [i, j])
T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j])
T.writes(var_decode_intermediate[v_i, v_j])
var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(4096), T.int64(4096)):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
with T.init():
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(4096)):
with T.sblock("T_add"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
T.writes(Out[v_ax0, v_ax1, v_ax2])
Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
@T.prim_func(private=True)
def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
T.func_attr({"tir.is_scheduled": True})
# with T.sblock("root"):
var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local")
A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared")
var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax1_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)):
for ax2_3_1_init in T.vectorized(T.int64(2)):
with T.sblock("matmul_init"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init)
T.reads()
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
for ax3_0 in range(T.int64(256)):
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("A_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(A[v0, v1, v2])
T.writes(A_reindex_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("var_decode_intermediate_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
for ax2_3_1 in T.vectorized(T.int64(2)):
with T.sblock("matmul_update"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1)
v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
for ax2_1_1 in T.vectorized(T.int64(2)):
with T.sblock("var_matmul_intermediate_reindex_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
T.writes(Out[T.int64(0), v1, v2])
Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
# fmt: on
mod = tvm.IRModule({"main": before})
with Target("nvidia/geforce-gtx-1080-ti"):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_skip_gemv():
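    # The batch and sequence extents are both 1, so this is really a GEMV:
    # the Matmul rule should not apply and the function must stay unchanged.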
# fmt: off
@T.prim_func(private=True)
def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")):
T.func_attr({"tir.noalias": True})
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)))
var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)))
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.sblock("decode"):
v_i, v_j = T.axis.remap("SS", [i, j])
T.reads(W[v_i // T.int64(8), v_j], S[v_i // T.int64(32), v_j])
T.writes(var_decode_intermediate[v_i, v_j])
var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
T.reads(A[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
with T.init():
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
with T.sblock("T_add"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
T.writes(Out[v_ax0, v_ax1, v_ax2])
Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
# fmt: on
expected = before
mod = tvm.IRModule({"main": before})
with Target("nvidia/geforce-gtx-1080-ti"):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_output_fp32():
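    # float16 operands accumulated into a float32 intermediate: the bias add,
    # the cast back to float16, and the residual add should all be fused into
    # the final write-back block.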
# fmt: off
@T.prim_func(private=True)
def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle):
T.func_attr({"tir.noalias": True})
n = T.int64()
lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16")
lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
# with T.sblock("root"):
p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))
var_compute_intermediate = T.alloc_buffer((T.int64(4096),))
var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))
var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.sblock("decode"):
v_i, v_j = T.axis.remap("SS", [i, j])
T.reads(lv13[v_i, v_j // T.int64(8)], lv14[v_i, v_j // T.int64(32)])
T.writes(p_output0_intermediate_1[v_i, v_j])
p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v_i, v_j // T.int64(32)]
for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
T.reads(lv48[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2])
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
with T.init():
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv48[v_i0, v_i1, v_k]) * T.Cast("float32", p_output0_intermediate_1[v_k, v_i2])
for i0 in range(T.int64(4096)):
with T.sblock("compute"):
v_i0 = T.axis.spatial(T.int64(4096), i0)
T.reads(lv13_1[v_i0])
T.writes(var_compute_intermediate[v_i0])
var_compute_intermediate[v_i0] = T.Cast("float32", lv13_1[v_i0])
for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
with T.sblock("T_add"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], var_compute_intermediate[v_ax2])
T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + var_compute_intermediate[v_ax2]
for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(4096)):
with T.sblock("compute_1"):
v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2])
T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2])
var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2])
for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
with T.sblock("T_add_1"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv3[v_ax0, v_ax1, v_ax2])
T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv3[v_ax0, v_ax1, v_ax2]
@T.prim_func(private=True)
def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
n = T.int64()
lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16")
lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
# with T.sblock("root"):
var_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
lv48_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
p_output0_intermediate_1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)):
for ax2_3_1_init in T.vectorized(T.int64(2)):
with T.sblock("matmul_init"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init)
T.reads()
T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
for ax3_0 in range(T.int64(256)):
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("lv48_reindex_pad_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(lv48[v0, v1, v2])
T.writes(lv48_reindex_pad_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv48_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv48[v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("p_output0_intermediate_1_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(lv13[v2, v1 // T.int64(8)], lv14[v2, v1 // T.int64(32)])
T.writes(p_output0_intermediate_1_reindex_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
p_output0_intermediate_1_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v2, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v2, v1 // T.int64(32)]
for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
for ax2_3_1 in T.vectorized(T.int64(2)):
with T.sblock("matmul_update"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1)
v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
T.reads(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv48_reindex_pad_shared[T.int64(0), v1, v3], p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3])
T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + T.Cast("float32", lv48_reindex_pad_shared[T.int64(0), v1, v3]) * T.Cast("float32", p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3])
for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
for ax2_1_1 in T.vectorized(T.int64(2)):
with T.sblock("var_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2])
T.writes(p_output0_intermediate[T.int64(0), v1, v2])
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
# fmt: on
mod = tvm.IRModule({"main": before})
with Target("nvidia/geforce-gtx-1080-ti"):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_inline_consumer_chain():
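    # NT_matmul followed by a chain of elementwise consumers (sigmoid, multiply,
    # squeeze, cast, multiply): the whole chain should be inlined into the
    # single epilogue block that writes the output.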
# fmt: off
@T.prim_func(private=True)
def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle):
T.func_attr({"tir.noalias": True})
n = T.int64()
lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16")
lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048)))
var_T_multiply_intermediate = T.match_buffer(p_output0, (n, T.int64(2048)), "float16")
# with T.sblock("root"):
var_NT_matmul_intermediate = T.alloc_buffer((n, T.int64(2048)), "float16")
compute = T.alloc_buffer((n, T.int64(2048)), "float16")
var_T_multiply_intermediate_1 = T.alloc_buffer((n, T.int64(2048)), "float16")
var_T_squeeze_intermediate = T.alloc_buffer((n, T.int64(2048)))
var_compute_intermediate = T.alloc_buffer((n, T.int64(2048)), "float16")
for i0, i1, k in T.grid(n, T.int64(2048), T.int64(2048)):
with T.sblock("NT_matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(lv26[v_i0, v_k], lv9[v_i1, v_k])
T.writes(var_NT_matmul_intermediate[v_i0, v_i1])
with T.init():
var_NT_matmul_intermediate[v_i0, v_i1] = T.float16(0)
var_NT_matmul_intermediate[v_i0, v_i1] = var_NT_matmul_intermediate[v_i0, v_i1] + lv26[v_i0, v_k] * lv9[v_i1, v_k]
for i0, i1 in T.grid(n, T.int64(2048)):
with T.sblock("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(var_NT_matmul_intermediate[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1])
for ax0, ax1 in T.grid(n, T.int64(2048)):
with T.sblock("T_multiply"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1])
T.writes(var_T_multiply_intermediate_1[v_ax0, v_ax1])
var_T_multiply_intermediate_1[v_ax0, v_ax1] = var_NT_matmul_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
for ax0, ax1 in T.grid(n, T.int64(2048)):
with T.sblock("T_squeeze"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(lv52[T.int64(0), v_ax0, v_ax1])
T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1])
var_T_squeeze_intermediate[v_ax0, v_ax1] = lv52[T.int64(0), v_ax0, v_ax1]
for i0, i1 in T.grid(n, T.int64(2048)):
with T.sblock("compute_1"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(var_T_squeeze_intermediate[v_i0, v_i1])
T.writes(var_compute_intermediate[v_i0, v_i1])
var_compute_intermediate[v_i0, v_i1] = T.Cast("float16", var_T_squeeze_intermediate[v_i0, v_i1])
for ax0, ax1 in T.grid(n, T.int64(2048)):
with T.sblock("T_multiply_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(var_compute_intermediate[v_ax0, v_ax1], var_T_multiply_intermediate_1[v_ax0, v_ax1])
T.writes(var_T_multiply_intermediate[v_ax0, v_ax1])
var_T_multiply_intermediate[v_ax0, v_ax1] = var_compute_intermediate[v_ax0, v_ax1] * var_T_multiply_intermediate_1[v_ax0, v_ax1]
@T.prim_func(private=True)
def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
n = T.int64()
lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16")
lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048)))
var_T_multiply_intermediate = T.match_buffer(p_output0, (n, T.int64(2048)), "float16")
# with T.sblock("root"):
var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="local")
lv26_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2048)), "float16", scope="shared")
lv9_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(2048), T.int64(2048)), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(T.int64(32), thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)):
for ax2_3_1_init in T.vectorized(T.int64(2)):
with T.sblock("NT_matmul_init"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init)
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0)
for ax3_0 in range(T.int64(128)):
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("lv26_reindex_pad_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(lv26[v1, v2])
T.writes(lv26_reindex_pad_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv26_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv26[v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.sblock("lv9_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
v2 = T.axis.spatial(T.int64(2048), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(lv9[v1, v2])
T.writes(lv9_reindex_shared[v0, v1, v2])
T.sblock_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv9_reindex_shared[v0, v1, v2] = lv9[v1, v2]
for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
for ax2_3_1 in T.vectorized(T.int64(2)):
with T.sblock("NT_matmul_update"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1)
v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(16) + ax3_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv26_reindex_pad_shared[T.int64(0), v1, v3], lv9_reindex_shared[T.int64(0), v2, v3])
T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + lv26_reindex_pad_shared[T.int64(0), v1, v3] * lv9_reindex_shared[T.int64(0), v2, v3]
for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
for ax2_1_1 in T.vectorized(T.int64(2)):
with T.sblock("var_NT_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
T.reads(lv52[T.int64(0), v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
T.writes(var_T_multiply_intermediate[v1, v2])
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
# fmt: on
mod = tvm.IRModule({"main": before})
with Target("nvidia/geforce-gtx-1080-ti"):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_matmul_android():
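    # The dynamic-m matmul again, but for an OpenCL/Android target: the rule
    # pads m to a multiple of 16, materializes a re-laid-out copy of the
    # activation, and accumulates in registers instead of using shared memory.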
# fmt: off
@T.prim_func(private=True)
def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
m = T.int64()
inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
for i0, i1, i2, k in T.grid(T.int64(1), m, T.int64(4096), T.int64(4096)):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
with T.init():
matmul[v_i0, v_i1, v_i2] = T.float32(0)
matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
@T.prim_func(private=True)
def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
T.func_attr({"tir.is_scheduled": True})
m = T.int64()
inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
# with T.sblock("root"):
inp0_reindex_pad = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)))
matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local")
inp0_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), scope="local")
for i0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
for i1_0 in T.thread_binding(((m + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"):
for i2_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
for i1_2 in T.vectorized(T.int64(16)):
with T.sblock("inp0_reindex_pad"):
v0 = T.axis.spatial(T.int64(1), i0)
v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2)
v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1)
T.where((i1_0 * T.int64(4) + i1_1) * T.int64(16) + i1_2 < (m + T.int64(15)) // T.int64(16) * T.int64(16))
T.reads(inp0[v0, v1, v2])
T.writes(inp0_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)])
inp0_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
for i2_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
for i0_i1_fused_0 in T.thread_binding(((m + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"):
for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for i0_i1_fused_2_init in T.unroll(T.int64(16)):
for i2_2_init in T.vectorized(T.int64(8)):
with T.sblock("matmul_init"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init)
v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2_init < (m + T.int64(15)) // T.int64(16) * T.int64(16))
T.reads()
T.writes(matmul_pad_local[v_i0, v_i1, v_i2])
matmul_pad_local[v_i0, v_i1, v_i2] = T.float32(0)
for k_0, k_1 in T.grid(T.int64(128), T.int64(4)):
for k_2 in T.unroll(T.int64(8)):
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
for ax3 in T.vectorized(T.int64(16)):
with T.sblock("inp0_reindex_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1)
v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2)
v3 = T.axis.spatial(T.int64(16), ax3)
T.where(i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 < (m + T.int64(15)) // T.int64(16))
T.reads(inp0_reindex_pad[v0, v1, v2, v3])
T.writes(inp0_reindex_pad_local[v0, v1, v2, v3])
inp0_reindex_pad_local[v0, v1, v2, v3] = inp0_reindex_pad[v0, v1, v2, v3]
for i0_i1_fused_2 in T.unroll(T.int64(16)):
for i2_2 in T.vectorized(T.int64(8)):
with T.sblock("matmul_update"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2)
v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2 < (m + T.int64(15)) // T.int64(16) * T.int64(16))
T.reads(matmul_pad_local[v_i0, v_i1, v_i2], inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)], inp1[v_k, v_i2])
T.writes(matmul_pad_local[v_i0, v_i1, v_i2])
matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * inp1[v_k, v_i2]
for ax0 in T.unroll(T.int64(16)):
for ax1 in T.vectorized(T.int64(8)):
with T.sblock("matmul_pad"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m)
T.reads(matmul_pad_local[v0, v1, v2])
T.writes(matmul[v0, v1, v2])
matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2]
    # fmt: on
    mod = tvm.IRModule({"main": before})
with Target("opencl", host={"kind": "llvm", "mtriple": "aarch64-linux-android"}):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_fused_dequant_matmul_android():
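    # Fused 4-bit dequantize + matmul + bias add on the OpenCL/Android target:
    # quantized weights and scales are cached in local memory, the dequantize
    # runs per tile, and the bias add is fused into the guarded epilogue.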
# fmt: off
@T.prim_func(private=True)
def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
T.func_attr({"tir.noalias": True})
seq_len = T.int64()
rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
# with T.sblock("root"):
compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16")
for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
with T.sblock("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(lv452[v_i0 // T.int64(8), v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
with T.sblock("dequantize"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1])
T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1]
for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)):
with T.sblock("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
with T.init():
matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)):
with T.sblock("T_add"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
@T.prim_func(private=True)
def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
seq_len = T.int64()
rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
# with T.sblock("root"):
dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
rms_norm130_reindex_pad = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16")
matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(12288)), "float16", scope="local")
rms_norm130_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16", scope="local")
lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
for i0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
for i1_0 in T.thread_binding(((seq_len + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"):
for i2_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
for i1_2 in T.vectorized(T.int64(16)):
with T.sblock("rms_norm130_reindex_pad"):
v0 = T.axis.spatial(T.int64(1), i0)
v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2)
v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1)
T.where((i1_0 * T.int64(4) + i1_1) * T.int64(16) + i1_2 < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16))
T.reads(rms_norm130[v0, v1, v2])
T.writes(rms_norm130_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)])
rms_norm130_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0))
for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
for i0_i1_fused_0 in T.thread_binding(((seq_len + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"):
for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for i0_i1_fused_2_init in T.unroll(T.int64(16)):
for i2_2_init in T.vectorized(T.int64(8)):
with T.sblock("matmul_init"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init)
v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2_init < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16))
T.reads()
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float16(0)
for k_0 in range(T.int64(128)):
for ax0 in range(T.int64(1)):
for ax1 in T.vectorized(T.int64(8)):
with T.sblock("lv453_local"):
v0 = T.axis.spatial(T.int64(128), k_0 + ax0)
v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.reads(lv453[v0, v1])
T.writes(lv453_local[v0, v1])
lv453_local[v0, v1] = lv453[v0, v1]
for k_1 in range(T.int64(4)):
for ax0 in range(T.int64(1)):
for ax1 in T.vectorized(T.int64(8)):
with T.sblock("lv452_local"):
v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(4) + k_1 + ax0)
v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.reads(lv452[v0, v1])
T.writes(lv452_local[v0, v1])
lv452_local[v0, v1] = lv452[v0, v1]
for k_2 in T.unroll(T.int64(8)):
for ax0 in T.vectorized(T.int64(8)):
with T.sblock("dequantize"):
v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2)
v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1])
T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1]
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
for ax3 in T.vectorized(T.int64(16)):
with T.sblock("rms_norm130_reindex_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1)
v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2)
v3 = T.axis.spatial(T.int64(16), ax3)
T.where(i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 < (seq_len + T.int64(15)) // T.int64(16))
T.reads(rms_norm130_reindex_pad[v0, v1, v2, v3])
T.writes(rms_norm130_reindex_pad_local[v0, v1, v2, v3])
rms_norm130_reindex_pad_local[v0, v1, v2, v3] = rms_norm130_reindex_pad[v0, v1, v2, v3]
for i0_i1_fused_2 in T.unroll(T.int64(16)):
for i2_2 in T.vectorized(T.int64(8)):
with T.sblock("matmul_update"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2)
v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2 < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16))
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)], dequantize_intermediate_intermediate_local[v_k, v_i2])
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * dequantize_intermediate_intermediate_local[v_k, v_i2]
for ax0 in T.unroll(T.int64(16)):
for ax1 in T.vectorized(T.int64(8)):
with T.sblock("T_add"):
v_ax0 = T.axis.spatial(T.int64(1), T.int64(0))
v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0)
v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len)
T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
# fmt: on
mod = tvm.IRModule({"main": before})
with Target("opencl", host={"kind": "llvm", "mtriple": "aarch64-linux-android"}):
mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)


if __name__ == "__main__":
tvm.testing.main()