# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import tempfile

import pytest

from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.local_rpc import LocalRPC
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
from tvm.dlight.benchmark import (
    benchmark,
    benchmark_prim_func,
    benchmark_relax_func,
    extract_prim_func,
    extract_from_relax,
    extract_func_info_from_prim_func,
)

import tvm.testing

# The Relax function `test` uses a symbolic var `n` that is not attached to any
# argument, so the module is constructed with check_well_formed=False.
# In principle, this var should be attached to an argument.
# pylint: disable=no-self-argument,invalid-name,line-too-long,no-method-argument
# fmt: off
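# `Module` below holds small dynamic-shape TIR kernels (`full1`, `full2`, `matmul1`)
# chained by the Relax entry `test`; the standalone `cuda_workload` PrimFunc that
# follows it is a pre-scheduled dynamic GPU matmul used by the benchmark tests.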
@I.ir_module(check_well_formed=False)
class Module:
    @T.prim_func
    def full1(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": True})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(1.0)

    @T.prim_func
    def full2(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": True})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(1.0)

    @T.prim_func
    def matmul1(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": True})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]

    @R.function
    def test():
        n = T.int64()
        R.func_attr({"tir_var_upper_bound": {"n": 2048}})
        cls = Module
        with R.dataflow():
            lv1 = R.call_tir(cls.full1, (), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1_1 = R.call_tir(cls.full1, (), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1_2 = R.call_tir(cls.full1, (), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2 = R.call_tir(cls.full2, (), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2_1 = R.call_tir(cls.full2, (), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv3 = R.call_tir(cls.matmul1, (lv1, lv2), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            R.output(lv3)
        return lv3


@T.prim_func
def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
    T.func_attr({"tir.is_scheduled": True})
    m = T.int64()
    inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
    matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
    # with T.block("root"):
    matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
    inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
    inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
    for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
        for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
            for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
                                    with T.block("matmul_init"):
                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                        v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                        v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
                                        T.reads()
                                        T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
                                        matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
                                for ax3_0 in range(T.int64(256)):
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
                                                    with T.block("inp0_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                        T.reads(inp0[v0, v1, v2])
                                                        T.writes(inp0_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
                                                    with T.block("inp1_reindex_shared"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                        T.reads(inp1[v2, v1])
                                                        T.writes(inp1_reindex_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
                                    for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
                                        with T.block("matmul_update"):
                                            v0 = T.axis.spatial(T.int64(1), ax0)
                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
                                            v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
                                            T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
                                            T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
                                            matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
                                for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
                                    for ax2_1_1 in T.vectorized(T.int64(2)):
                                        with T.block("matmul_reindex_pad_local"):
                                            v0 = T.axis.spatial(T.int64(1), ax0_1)
                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
                                            T.reads(matmul_reindex_pad_local[v0, v1, v2])
                                            T.writes(matmul[T.int64(0), v1, v2])
                                            if v1 < m:
                                                matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
# pylint: enable=no-self-argument,invalid-name,line-too-long,no-method-argument
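

# Benchmark the pre-scheduled `cuda_workload` through a LocalRPC session, sampling
# the dynamic dim "m" as 128, and check the concrete input shapes it reports.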
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_rpc():
    with LocalRPC() as rpc:
        rpc_config = ms.runner.RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        input_infos, _, _ = benchmark(
            cuda_workload,
            args=[
                ((1, "m", 4096), "float32"),
                ((4096, 4096), "float32"),
                ((1, "m", 4096), "float32"),
            ],
            dym_var_sample={"m": 128},
            target="nvidia/geforce-rtx-3070",
            rpc_config=rpc_config,
        )
        assert input_infos == [
            ((1, 128, 4096), "float32"),
            ((4096, 4096), "float32"),
            ((1, 128, 4096), "float32"),
        ]

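# Same benchmark as above, but run on the local device instead of through RPC.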
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_local():
    input_infos, _, _ = benchmark(
        cuda_workload,
        args=[
            ((1, "m", 4096), "float32"),
            ((4096, 4096), "float32"),
            ((1, "m", 4096), "float32"),
        ],
        dym_var_sample={"m": 128},
        target="nvidia/geforce-rtx-3070",
    )
    assert input_infos == [
        ((1, 128, 4096), "float32"),
        ((4096, 4096), "float32"),
        ((1, 128, 4096), "float32"),
    ]

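# Run `benchmark_prim_func` with defaults, taking the target from the enclosing
# `tvm.target.Target` scope.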
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_full_local():
    with tvm.target.Target("nvidia/geforce-rtx-3070"):
        benchmark_prim_func(
            cuda_workload,
        )

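# Run `benchmark_prim_func` over LocalRPC with an explicit evaluator config
# (10 runs x 10 repeats, no CPU cache flush).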
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_full_rpc():
    with LocalRPC() as rpc:
        rpc_config = ms.runner.RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        benchmark_prim_func(
            cuda_workload,
            target="nvidia/geforce-rtx-3070",
            rpc_config=rpc_config,
            evaluator_config=ms.runner.EvaluatorConfig(
                number=10,
                repeat=10,
                min_repeat_ms=0,
                enable_cpu_cache_flush=False,
            ),
        )

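# Benchmark the Relax entry function `test` of `Module` on a local LLVM target.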
def test_benchmark_relax_func():
    with tvm.target.Target("llvm -num-cores=4"):
        benchmark_relax_func(Module, "test")

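# Extract and print benchmark info for `full1`, passing its argument shapes and
# the dtype of the dynamic var "n" explicitly.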
def test_extract_prim_func_full1():
    print(
        extract_prim_func(
            model_name="TEST",
            relax_func_name="test",
            prim_func_name="full1",
            func=Module["full1"],  # type: ignore
            func_args=[((1, 32, 1, "n"), "float16")],
            dym_var_dict={"n": "int32"},
            weight=2,
            sample_number=10,
            target="llvm -num-cores=4",
        )
    )

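# Same extraction for `matmul1`, without passing explicit argument shapes or a
# dynamic-var dict.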
def test_extract_prim_func_matmul1():
    print(
        extract_prim_func(
            model_name="TEST",
            relax_func_name="test",
            prim_func_name="matmul1",
            func=Module["matmul1"],  # type: ignore
            weight=2,
            sample_number=10,
            target="llvm -num-cores=4",
        )
    )

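# Run `extract_from_relax` on the whole `Module`, writing its output under a
# temporary directory.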
def test_extract_from_relax():
    with tvm.target.Target("llvm -num-cores=4"):
        with tempfile.TemporaryDirectory() as filepath:
            extract_from_relax(
                Module,
                "TEST",
                file_path=filepath,
            )

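# `extract_func_info_from_prim_func` should report each PrimFunc's argument
# shapes and dtypes along with the dtypes of its dynamic shape variables.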
def test_extract_func_info_from_prim_func():
    assert (
        str(extract_func_info_from_prim_func(cuda_workload))
        == "([((1, m, 4096), 'float32'), ((4096, 4096), 'float32'), ((1, m, 4096), 'float32')], {'m': 'int64'})"
    )
    assert (
        str(extract_func_info_from_prim_func(Module["full1"]))
        == "([((1, 32, 1, n), 'float16')], {'n': 'int64'})"
    )
    assert (
        str(extract_func_info_from_prim_func(Module["matmul1"]))
        == "([((1, 32, 1, n), 'float16'), ((1, 32, n, 128), 'float16'), ((1, 32, 1, 128), 'float16')], {'n': 'int64'})"
    )
    assert (
        str(extract_func_info_from_prim_func(Module["full2"]))
        == "([((1, 32, n, 128), 'float16')], {'n': 'int64'})"
    )

if __name__ == "__main__":
    tvm.testing.main()