blob: 75fbaccb7fb9f7e091da2db83302c4240c3add21 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for tvm.tirx.bench utilities."""
import pytest
import torch
import tvm.testing
from tvm.tirx.bench import _compute_group_count, _parse_proton_tree, bench, tensor_bytes
# ── _parse_proton_tree ──────────────────────────────────────────────────────
# Fixture text passed to ``_parse_proton_tree`` by ``test_parse_proton_tree_basic``.
# Written one row per literal so every row carries its own explicit newline.
SAMPLE_TREE = (
    "├─ 1.500 tir\n"
    "│ ├─ 1.500 my_kernel_fn\n"
    "│ └─ 0.001 vectorized_elementwise_kernel\n"
    "└─ 0.800 cublas\n"
    "└─ 0.800 sm90_xmma_gemm_f16f16\n"
)
def test_parse_proton_tree_basic():
    """The shared sample tree yields one timing per impl and no baseline errors."""
    timings, failures = _parse_proton_tree(SAMPLE_TREE)
    expected_timings = {"tir": 1.5, "cublas": 0.8}
    assert timings == expected_timings
    assert failures == {}
def test_parse_proton_tree_filters_elementwise():
    """vectorized_elementwise_kernel and elementwise_kernel_with_index are skipped.

    Both filtered kernel names appear in the tree and both are given *larger*
    times than ``real_kernel``.  The parser reports the slowest remaining
    child per impl, so the expected 0.5 can only be produced when both
    elementwise rows are really dropped.  With tiny elementwise times (as this
    test previously used) the assertion would also pass with filtering
    disabled, because the real kernel would win the "slowest child" pick
    anyway.
    """
    # NOTE(review): relies on the filter being name-based, as the summary line
    # above states -- confirm against `_parse_proton_tree` if this ever fails.
    tree = (
        "├─ 2.100 tir\n"
        "│ ├─ 0.500 real_kernel\n"
        "│ ├─ 0.900 vectorized_elementwise_kernel\n"
        "│ └─ 0.700 elementwise_kernel_with_index\n"
    )
    impls, _ = _parse_proton_tree(tree)
    assert impls == {"tir": 0.5}
def test_parse_proton_tree_slowest_child():
    """With two kernels under one impl, the larger kernel time is reported."""
    rows = (
        "├─ 2.000 tir\n"
        "│ ├─ 0.300 kernel_a\n"
        "│ └─ 0.700 kernel_b\n"
    )
    timings, _errors = _parse_proton_tree(rows)
    # 0.7 (kernel_b) is expected, not the 2.000 on the impl row nor kernel_a's 0.3.
    assert timings == {"tir": 0.7}
def test_parse_proton_tree_baseline_errors():
    """A leading ``BASELINE_ERROR`` row lands in the errors mapping, keyed by impl name."""
    rows = (
        "BASELINE_ERROR: cublas: CUDA OOM\n"
        "├─ 1.000 tir\n"
        "│ └─ 1.000 my_kernel\n"
    )
    timings, failures = _parse_proton_tree(rows)
    assert failures == {"cublas": "CUDA OOM"}
    assert timings == {"tir": 1.0}
def test_parse_proton_tree_ansi_stripped():
    """A row wrapped in ANSI colour escapes parses the same as plain text."""
    green = "\x1b[32m"
    reset = "\x1b[0m"
    # Only the first row is coloured; the kernel row is left plain.
    coloured = f"{green}├─ 1.000 tir{reset}\n│ └─ 1.000 k\n"
    timings, _errors = _parse_proton_tree(coloured)
    assert timings == {"tir": 1.0}
def test_parse_proton_tree_empty():
    """Empty input produces two empty mappings."""
    timings, failures = _parse_proton_tree("")
    for mapping in (timings, failures):
        assert mapping == {}
# ── bench ───────────────────────────────────────────────────────────────────
@tvm.testing.requires_cuda
def test_bench_basic():
    """A single impl benchmarked with the event timer reports a positive time."""
    rows, cols = 256, 256

    def matmul(case):
        return torch.mm(case[0], case[1])

    def make_input():
        # Fresh fp16 CUDA operands on every call, returned with tensor_bytes(...) of both.
        lhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        rhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        return (lhs, rhs), tensor_bytes(lhs, rhs)

    results = bench(
        {"matmul": matmul},
        make_input,
        warmup=5,
        repeat=10,
        cooldown_s=0.0,
        timer="event",
    )
    timings = results["impls"]
    assert "matmul" in timings
    assert timings["matmul"] > 0
@tvm.testing.requires_cuda
def test_bench_multiple_impls():
    """Every entry of the impl dict is reported, each with a positive time."""
    rows, cols = 128, 128

    def mm(case):
        return torch.mm(case[0], case[1])

    def addmm(case):
        # The zero bias is allocated inside the benchmarked callable.
        bias = torch.zeros(rows, cols, device="cuda", dtype=torch.float16)
        return torch.addmm(bias, case[0], case[1])

    def make_input():
        lhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        rhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        return (lhs, rhs), tensor_bytes(lhs, rhs)

    results = bench(
        {"mm": mm, "addmm": addmm},
        make_input,
        warmup=5,
        repeat=10,
        cooldown_s=0.0,
        timer="event",
    )
    timings = results["impls"]
    assert set(timings.keys()) == {"mm", "addmm"}
    for elapsed in timings.values():
        assert elapsed > 0
@tvm.testing.requires_cuda
def test_bench_multiple_input_groups():
    """With ``l2_bytes`` set to 64 KiB the input factory is invoked more than once."""
    rows, cols = 128, 128
    factory_calls = 0

    def make_input():
        # Count invocations so the test can tell how many inputs were requested.
        nonlocal factory_calls
        factory_calls += 1
        lhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        rhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        return (lhs, rhs), tensor_bytes(lhs, rhs)

    def mm(case):
        return torch.mm(case[0], case[1])

    results = bench(
        {"mm": mm},
        make_input,
        warmup=5,
        repeat=20,
        cooldown_s=0.0,
        timer="event",
        l2_bytes=64 * 1024,
    )
    assert results["impls"]["mm"] > 0
    assert factory_calls > 1
# ── _compute_group_count ───────────────────────────────────────────────────
def test_compute_groups_small_tensors():
    """A 32 KiB input against a 128 MiB L2 yields 12289 groups."""
    l2_size = 128 * 1024 * 1024
    # 128 x 128 fp16 = 32 KiB; (3 * 128 MiB) / 32 KiB = 12288, plus one -> 12289.
    probe = torch.empty(128, 128, dtype=torch.float16)
    groups = _compute_group_count(tensor_bytes(probe), l2_bytes=l2_size)
    assert groups == 12289
def test_compute_groups_large_tensors():
    """An input far larger than 3x L2 needs just one group."""
    l2_size = 128 * 1024 * 1024
    # 16384 x 16384 fp32 = 1 GiB, well above 3 * 128 MiB = 384 MiB.
    probe = torch.empty(16384, 16384, dtype=torch.float32)
    groups = _compute_group_count(tensor_bytes(probe), l2_bytes=l2_size)
    assert groups == 1
def test_compute_groups_moderate_tensors():
    """An input equal to the L2 size yields floor(3 * L2 / input) + 1 = 4 groups."""
    l2_size = 128 * 1024 * 1024
    # 8192 x 8192 bf16 = 128 MiB; floor(384 MiB / 128 MiB) + 1 = 4.
    probe = torch.empty(8192, 8192, dtype=torch.bfloat16)
    groups = _compute_group_count(tensor_bytes(probe), l2_bytes=l2_size)
    assert groups == 4
@tvm.testing.requires_cuda
def test_bench_legacy_callable_api():
    """Passing one zero-argument callable to ``bench`` yields a positive result."""
    rows, cols = 128, 128
    lhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
    rhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)

    def run_matmul():
        return torch.mm(lhs, rhs)

    # The result is compared directly against 0 rather than indexed as a mapping.
    elapsed = bench(run_matmul, warmup=1, repeat=2, proton_name="legacy", flush_l2_size=1)
    assert elapsed > 0
@tvm.testing.requires_cuda
def test_bench_callable_inputs():
    """``bench`` accepts an input factory and calls it at least twice."""
    rows, cols = 256, 256
    factory_calls = 0

    def make_input():
        nonlocal factory_calls
        factory_calls += 1
        lhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        rhs = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
        return (lhs, rhs), tensor_bytes(lhs, rhs)

    def mm(case):
        return torch.mm(case[0], case[1])

    results = bench(
        {"mm": mm},
        make_input,
        warmup=5,
        repeat=10,
        cooldown_s=0.0,
        timer="event",
    )
    timings = results["impls"]
    assert "mm" in timings
    assert timings["mm"] > 0
    # No l2_bytes is passed here; the factory must still be invoked at least twice.
    assert factory_calls >= 2
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of exiting itself.
    # Propagate it so that running this file directly reports failures to the
    # calling shell/CI rather than always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))