blob: 6aea731d5c0250e8c1f03d7cfb63be2491f5ba06 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import tvm
from tvm import dlight as dl
from tvm.ir import IRModule, assert_structural_equal
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
def _check(mod_before: IRModule, mod_after: IRModule):
    """Schedule ``mod_before`` with the dlight GPU Transpose rule and compare.

    The rule is applied through ``dl.ApplyDefaultSchedule`` while the
    ``nvidia/geforce-rtx-3090-ti`` target is the active target. The resulting
    module must be structurally equal to ``mod_after``.
    """
    with Target("nvidia/geforce-rtx-3090-ti"):
        apply_transpose = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
            dl.gpu.Transpose(),
        )
        scheduled = apply_transpose(mod_before)  # pylint: disable=not-callable
        assert_structural_equal(scheduled, mod_after)
def test_transpose():
    """Check the Transpose rule on a plain float32 (512, 4096) -> (4096, 512) transpose.

    ``After`` is the expected IR after scheduling. The input tile is staged
    through a ``scope="shared"`` buffer, and both the copy and the transpose
    are bound to an 8 x 16 (threadIdx.y x threadIdx.x) thread layout.
    """
    # fmt: off
    @I.ir_module
    class Before:
        @T.prim_func
        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")):
            T.func_attr({"tir.noalias": True})
            # Unscheduled reference: T_transpose[i, j] = rxplaceholder[j, i].
            for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

    @I.ir_module
    class After:
        @T.prim_func
        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            rxplaceholder_shared = T.alloc_buffer((T.int64(512), T.int64(4096)), scope="shared")
            # Each (blockIdx.y, blockIdx.x) pair owns an 8-row x 16-column tile of
            # T_transpose. The rows start at ax0_0_0 * 8 and the columns start at ax1_0 * 16.
            for ax0_0_0 in T.thread_binding(T.int64(512), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for ax1_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"):
                    # Stage 1: the 8 x 16 threads copy the matching 16 x 8 slice of
                    # rxplaceholder into rxplaceholder_shared, one element per thread.
                    # The fused thread index is split with // 8 (row) and % 8 (column).
                    for ax0_ax1_fused_0 in range(T.int64(1)):
                        for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.unroll(T.int64(1)):
                                    with T.block("rxplaceholder_shared"):
                                        v0 = T.axis.spatial(T.int64(512), ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) // T.int64(8))
                                        v1 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(8) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) % T.int64(8))
                                        T.reads(rxplaceholder[v0, v1])
                                        T.writes(rxplaceholder_shared[v0, v1])
                                        # NOTE(review): storage-alignment annotation on the shared
                                        # buffer (axis 0, factor 32, offset 1). It is presumably
                                        # there to reduce shared-memory bank conflicts; confirm.
                                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 1]]})
                                        rxplaceholder_shared[v0, v1] = rxplaceholder[v0, v1]
                    # Stage 2: each thread writes one element of the output tile. It reads
                    # the transposed position [v1, v0] from the shared copy.
                    for ax0_0_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                        for ax1_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                            for ax0_1_0 in range(T.int64(1)):
                                for ax0_1_1 in range(T.int64(1)):
                                    with T.block("T_transpose"):
                                        v0 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(8) + ax0_0_1 + ax0_1_0 + ax0_1_1)
                                        v1 = T.axis.spatial(T.int64(512), ax1_0 * T.int64(16) + ax1_1)
                                        T.reads(rxplaceholder_shared[v1, v0])
                                        T.writes(T_transpose[v0, v1])
                                        T_transpose[v0, v1] = rxplaceholder_shared[v1, v0]
    # fmt: on
    _check(Before, After)
def test_decode_transpose():
    """Check the Transpose rule when its producer is a 4-bit decode stage.

    In ``Before``, ``decode`` unpacks eight 4-bit fields per uint32 of
    ``rxplaceholder`` (row ``v_i // 8``, shift ``(v_i % 8) * 4``, mask 15).
    Each field is multiplied by one float32 and added to another. Both
    float32 values are built from the low and high 16-bit halves of
    ``rxplaceholder_1[v_i // 32, v_j]``, shifted into the upper half of a
    float32. The result is then transposed. In ``After``, the decode is
    computed directly into the shared-memory fetch of the transpose.
    """
    # fmt: off
    @I.ir_module
    class Before:
        @T.prim_func
        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")):
            T.func_attr({"tir.noalias": True})
            decode = T.alloc_buffer((T.int64(4096), T.int64(4096)))
            # Decode stage: 4-bit field extraction followed by a multiply-add with
            # values reinterpreted from 16-bit halves of rxplaceholder_1.
            for i, j in T.grid(T.int64(4096), T.int64(4096)):
                with T.block("decode"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j])
                    T.writes(decode[v_i, v_j])
                    decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
            # Transpose stage: T_transpose[i, j] = decode[j, i].
            for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(decode[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]

    @I.ir_module
    class After:
        @T.prim_func
        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            decode_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="shared")
            # Each (blockIdx.y, blockIdx.x) pair owns a 64-row x 16-column tile of
            # T_transpose. The rows start at ax0_0_0 * 64 and the columns start at ax1_0 * 16.
            for ax0_0_0 in T.thread_binding(T.int64(64), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for ax1_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
                    # Stage 1: the 8 x 16 threads decode a 16 x 64 slice straight into
                    # decode_shared. Each thread handles 8 consecutive fused positions
                    # (unrolled), split with // 64 (row) and % 64 (column).
                    for ax0_ax1_fused_0 in range(T.int64(1)):
                        for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.unroll(T.int64(8)):
                                    with T.block("decode_shared"):
                                        v0 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) // T.int64(64))
                                        v1 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(64) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) % T.int64(64))
                                        T.reads(rxplaceholder[v0 // T.int64(8), v1], rxplaceholder_1[v0 // T.int64(32), v1])
                                        T.writes(decode_shared[v0, v1])
                                        # NOTE(review): storage-alignment annotation on the shared
                                        # buffer (axis 0, factor 32, offset 1). It is presumably
                                        # there to reduce shared-memory bank conflicts; confirm.
                                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 1]]})
                                        decode_shared[v0, v1] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v0 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v0 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
                    # Stage 2: each thread writes 8 consecutive output rows in one column.
                    # The rows are processed as 2 groups of 4, with each group vectorized
                    # along v0, reading the transposed position from decode_shared.
                    for ax0_0_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                        for ax1_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                            for ax0_1_0 in range(T.int64(2)):
                                for ax0_1_1 in T.vectorized(T.int64(4)):
                                    with T.block("T_transpose"):
                                        v0 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(64) + ax0_0_1 * T.int64(8) + ax0_1_0 * T.int64(4) + ax0_1_1)
                                        v1 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + ax1_1)
                                        T.reads(decode_shared[v1, v0])
                                        T.writes(T_transpose[v0, v1])
                                        T_transpose[v0, v1] = decode_shared[v1, v0]
    # fmt: on
    _check(Before, After)
def test_decode_int3_transpose():
    """Check the Transpose rule when its producer is a 3-bit float16 decode stage.

    In ``Before``, ``decode`` unpacks ten 3-bit fields per uint32 of ``A``
    (row ``v_i // 10``, shift ``(v_i % 10) * 3``, mask 7). It subtracts 3 and
    multiplies by ``B[v_i // 40, v_j]``, then transposes the result. Because
    10 does not divide the tile sizes evenly, ``After`` needs ``T.where``
    guards on both the shared-memory fetch and the transpose write.
    """
    # fmt: off
    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
            T.func_attr({"tir.noalias": True})
            decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
            # Decode stage: the 3-bit field minus 3, multiplied by one B value per 40 rows.
            for i, j in T.grid(T.int64(4096), T.int64(4096)):
                with T.block("decode"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j])
                    T.writes(decode_1[v_i, v_j])
                    decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
            # Transpose stage: T_transpose[i, j] = decode_1[j, i].
            for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(decode_1[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0]

    @I.ir_module
    class After:
        @T.prim_func
        def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
            # with T.block("root"):
            decode_1_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16", scope="shared")
            # Each (blockIdx.y, blockIdx.x) pair owns an 80-row x 16-column tile of
            # T_transpose, starting at row ax0_0_0 * 80 and column ax1_0 * 16.
            # Because 52 * 80 = 4160 > 4096, the last row tile is partially masked.
            for ax0_0_0 in T.thread_binding(T.int64(52), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for ax1_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
                    # Stage 1: decode into decode_1_shared. The fused position runs over
                    # 2 x 8 x 16 x 10 = 2560 slots, but T.where keeps only the first
                    # 1312 (= 16 * 82) and columns < 4096.
                    # NOTE(review): positions are split with // 82 and % 82 while the
                    # tile origin advances by 80, so each row fetch extends 2 columns
                    # past the 80-wide tile. This is pinned as the rule's expected
                    # output; confirm it is intended.
                    for ax0_ax1_fused_0 in range(T.int64(2)):
                        for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.unroll(T.int64(10)):
                                    with T.block("decode_1_shared"):
                                        v0 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 * T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) // T.int64(82))
                                        v1 = T.axis.spatial(T.int64(4096), ax0_0_0 * T.int64(80) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 * T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) % T.int64(82))
                                        T.where(ax0_0_0 * T.int64(80) + (((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) + ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3) % T.int64(82) < T.int64(4096) and ((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) + ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3 < T.int64(1312))
                                        T.reads(A[v0 // T.int64(10), v1], B[v0 // T.int64(40), v1])
                                        T.writes(decode_1_shared[v0, v1])
                                        # NOTE(review): storage-alignment annotation on the shared
                                        # buffer (axis 0, factor 32, offset 1). It is presumably
                                        # there to reduce shared-memory bank conflicts; confirm.
                                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 1]]})
                                        decode_1_shared[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v0 // T.int64(10), v1], T.Cast("uint32", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v0 // T.int64(40), v1]
                    # Stage 2: each thread writes 10 consecutive output rows in one column.
                    # They are covered by 3 vectorized groups of 4 (12 lanes). T.where
                    # masks lanes >= 10, row groups >= 410, and rows >= 4096.
                    for ax0_0_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                        for ax1_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                            for ax0_1_0 in range(T.int64(3)):
                                for ax0_1_1 in T.vectorized(T.int64(4)):
                                    with T.block("T_transpose"):
                                        v0 = T.axis.spatial(T.int64(4096), (ax0_0_0 * T.int64(8) + ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + ax0_1_1))
                                        v1 = T.axis.spatial(T.int64(4096), ax1_0 * T.int64(16) + ax1_1)
                                        T.where((ax0_0_0 * T.int64(8) + ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + ax0_1_1) < T.int64(4096) and ax0_0_0 * T.int64(8) + ax0_0_1 < T.int64(410) and ax0_1_0 * T.int64(4) + ax0_1_1 < T.int64(10))
                                        T.reads(decode_1_shared[v1, v0])
                                        T.writes(T_transpose[v0, v1])
                                        T_transpose[v0, v1] = decode_1_shared[v1, v0]
    # fmt: on
    _check(Before, After)