blob: a4eaa3ad748c557b4736108963229dc4ce76e84b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import tvm.testing
from tvm import dlight as dl
from tvm.ir import assert_structural_equal
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
def test_fallback():
    """dl.gpu.Fallback under a CUDA target should inline the intermediate
    transpose buffer and bind the fused spatial loop of the remaining
    reshape block to blockIdx.x / threadIdx.x.
    """
    # Unscheduled input: transpose A into a temporary buffer B, then
    # reshape B into the (1, 1, 4096) output C.
    @I.ir_module
    class Before:
        @T.prim_func
        def main(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            B = T.alloc_buffer((1, 1, 32, 128), "float16")
            for i, j, k, l in T.grid(1, 1, 32, 128):
                with T.block("T_transpose"):
                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
            for i, j, k in T.grid(1, 1, 4096):
                with T.block("T_reshape"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
    # Expected output: the transpose is folded into the reshape, whose
    # 4096-element spatial domain is split over 4 blocks x 1024 threads,
    # and the function is marked "tir.is_scheduled" so later passes skip it.
    @I.ir_module
    class After:
        @T.prim_func
        def main(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            T.func_attr({"tir.is_scheduled": True})
            for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_reshape"):
                        v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1)
                        T.reads(A[0, v0 // 128, 0, v0 % 128])
                        T.writes(C[0, 0, v0])
                        C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128]
    target = Target("nvidia/geforce-rtx-3090-ti")
    with target:
        # Apply only the Fallback rule; `Target.current` supplies the GPU info.
        mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
            dl.gpu.Fallback(),
        )(Before)
    assert_structural_equal(mod, After)
def test_fallback_reduction():
    """dl.gpu.Fallback should decompose a reduction block into separate
    init and update blocks, with T.where predicates masking off the
    threads beyond the single spatial element.
    """
    # Unscheduled input: sum-reduce the 6144 elements of A into B[0].
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), "float32")):
            for ax0, ax1 in T.grid(1, 6144):
                with T.block("block"):
                    v0 = T.axis.spatial(1, ax0)
                    v1 = T.axis.reduce(6144, ax1)
                    T.reads(A[v0, v1])
                    T.writes(B[v0])
                    with T.init():
                        B[v0] = T.float32(0)
                    B[v0] = B[v0] + T.Cast("float32", A[v0, v1])
    # Expected output on apple/m1-gpu: one block of 1024 threads (int64
    # extents), with only thread (0, 0) active via the T.where guard; the
    # reduction loop itself stays serial inside that thread.
    @I.ir_module
    class Expected:
        @T.prim_func
        def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), "float32")):
            T.func_attr({"tir.is_scheduled": True})
            for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                    with T.block("block_init"):
                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                        T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1))
                        T.reads()
                        T.writes(B[0])
                        B[0] = T.float32(0)
                    for ax1 in range(6144):
                        with T.block("block_update"):
                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
                            v1 = T.axis.reduce(6144, ax1)
                            T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1))
                            T.reads(B[0], A[0, v1])
                            T.writes(B[0])
                            B[0] = B[0] + T.Cast("float32", A[0, v1])
    with Target("apple/m1-gpu"):
        mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
            dl.gpu.Fallback(),
        )(Module)
    assert_structural_equal(mod, Expected)
def test_fallback_irregular_spatial():
    """dl.gpu.Fallback should handle symbolic (dynamic) loop extents: the
    fused spatial loop is launched over ceil(total / 1024) blocks and the
    tail threads are masked off with T.where.
    """
    # Input: an indirect gather through a page table into `values`
    # (looks like a paged KV-cache read — layout assumed, not verified
    # here).  All extents are symbolic T.int32 vars bound via match_buffer.
    @T.prim_func(private=True)
    def func(
        var_pages: T.handle,
        var_page_table_indptr: T.handle,
        var_page_table_values: T.handle,
        var_values: T.handle,
        seq_id: T.int32,
    ):
        nhead = T.int32()
        nlayer = T.int32()
        seqlen = T.int32()
        npage = T.int32()
        page_size = T.int32()
        num_total_pages = T.int32()
        num_total_seqs_plus_1 = T.int32()
        pages = T.match_buffer(var_pages, (num_total_pages, nlayer, nhead, page_size), "float16")
        page_table_indptr = T.match_buffer(var_page_table_indptr, (num_total_seqs_plus_1,), "int32")
        page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32")
        values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16")
        for l, h, pos in T.grid(nlayer, nhead, seqlen):
            with T.block("block"):
                vl, vh, vp = T.axis.remap("SSS", [l, h, pos])
                values[vl, vh, vp] = pages[
                    page_table_values[page_table_indptr[seq_id] + T.floordiv(vp, page_size)],
                    vl,
                    vh,
                    T.floormod(vp, page_size),
                ]
    # Expected output: the three spatial axes are fused and split over
    # (nlayer*nhead*seqlen + 1023) // 1024 blocks of 1024 threads, with the
    # axis decomposition recovered via div/mod and a T.where tail guard.
    # fmt: off
    @T.prim_func(private=True)
    def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_table_values: T.handle, var_values: T.handle, seq_id: T.int32):
        T.func_attr({"tir.is_scheduled": True})
        nhead = T.int32()
        nlayer = T.int32()
        seqlen = T.int32()
        npage = T.int32()
        page_size = T.int32()
        num_total_pages = T.int32()
        num_total_seqs_plus_1 = T.int32()
        pages = T.match_buffer(var_pages, (num_total_pages, nlayer, nhead, page_size), "float16")
        page_table_indptr = T.match_buffer(var_page_table_indptr, (num_total_seqs_plus_1,), "int32")
        page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32")
        values = T.match_buffer(var_values, (nlayer, nhead, seqlen), "float16")
        for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen + 1023) // 1024, thread="blockIdx.x"):
            for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
                    v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
                    v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1) % seqlen)
                    T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < nlayer * nhead * seqlen)
                    T.reads(pages[page_table_values[page_table_indptr[seq_id] + v2 // page_size], v0, v1, v2 % page_size], page_table_values[page_table_indptr[seq_id] + v2 // page_size], page_table_indptr[seq_id])
                    T.writes(values[v0, v1, v2])
                    values[v0, v1, v2] = pages[page_table_values[page_table_indptr[seq_id] + v2 // page_size], v0, v1, v2 % page_size]
    # fmt: on
    target = Target("nvidia/geforce-rtx-3090-ti")
    with target:
        # Wrap the bare PrimFunc into a module so the module pass can run.
        mod = tvm.IRModule({"main": func})
        mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
            dl.gpu.Fallback(),
        )(mod)
    assert_structural_equal(mod["main"], expected)
def test_gpu_fallback_ignores_non_gpu_functions():
    """Functions carrying an explicit non-GPU "target" attribute (here
    llvm) must pass through dl.gpu.Fallback untouched; only functions
    scheduled under the current GPU target receive the fallback schedule.
    """
    @I.ir_module
    class Before:
        # This function has no "target" attribute, and is scheduled
        # using the `Target.current`.
        @T.prim_func
        def gpu_func(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            B = T.alloc_buffer((1, 1, 32, 128), "float16")
            for i, j, k, l in T.grid(1, 1, 32, 128):
                with T.block("T_transpose"):
                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
            for i, j, k in T.grid(1, 1, 4096):
                with T.block("T_reshape"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
        # This function is identical, except that it is explicitly
        # annotated with the "target" attribute, and is scheduled
        # based on the annotation's target.
        @T.prim_func
        def cpu_func(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            T.func_attr({"target": T.target("llvm")})
            B = T.alloc_buffer((1, 1, 32, 128), "float16")
            for i, j, k, l in T.grid(1, 1, 32, 128):
                with T.block("T_transpose"):
                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
            for i, j, k in T.grid(1, 1, 4096):
                with T.block("T_reshape"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
    # Expected output: gpu_func receives the same fallback schedule as in
    # test_fallback, while cpu_func is reproduced verbatim.
    @I.ir_module
    class After:
        @T.prim_func
        def gpu_func(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            T.func_attr({"tir.is_scheduled": True})
            for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"):
                for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_reshape"):
                        v0 = T.axis.spatial(4096, ax0_fused_0 * 1024 + ax0_fused_1)
                        T.reads(A[0, v0 // 128, 0, v0 % 128])
                        T.writes(C[0, 0, v0])
                        C[0, 0, v0] = A[0, v0 // 128, 0, v0 % 128]
        @T.prim_func
        def cpu_func(
            A: T.Buffer((1, 32, 1, 128), "float16"),
            C: T.Buffer((1, 1, 4096), "float16"),
        ):
            T.func_attr({"target": T.target("llvm")})
            B = T.alloc_buffer((1, 1, 32, 128), "float16")
            for i, j, k, l in T.grid(1, 1, 32, 128):
                with T.block("T_transpose"):
                    vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
                    B[vi, vj, vk, vl] = A[vi, vk, vj, vl]
            for i, j, k in T.grid(1, 1, 4096):
                with T.block("T_reshape"):
                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                    C[vi, vj, vk] = B[0, 0, vk % 4096 // 128, vk % 128]
    with Target("cuda"):
        mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
            dl.gpu.Fallback(),
        )(Before)
    assert_structural_equal(mod, After)
if __name__ == "__main__":
    # Delegate to TVM's test entry point so the file can be run directly.
    tvm.testing.main()