blob: 01c449a387311f30799b8a698b74f678c43af5b4 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for ExecContext (RFC v3 §6). Cases mirror RFC §8.1 -- §8.10."""
from __future__ import annotations
import pytest
from tvm.tirx.exec_context import (
CLUSTER,
CTA,
LANE_CTA_THREAD,
LANE_FLAT,
LANE_W_INNER,
LANE_WG_OUTER,
LANE_WG_THREAD,
THREAD,
WARP,
WARPGROUP,
AxisRange,
ExecContext,
ExecContextError,
LaneBinding,
filter_modulo,
filter_narrow,
initial_A,
scope_switch,
)
# -- canonical bindings declared at kernel entry (see RFC §8 naming conv) --
# Flat view of the warpid axis with a declared extent of 16.
WARP_FLAT = LaneBinding(
    axis="warpid",
    kind=LANE_FLAT,
    declared_extent=16,
)
# Outer (LANE_WG_OUTER) factor of the warpid axis, declared extent 4.
WG_OUTER = LaneBinding(
    axis="warpid",
    kind=LANE_WG_OUTER,
    declared_extent=4,
)
# Inner (LANE_W_INNER) factor of the warpid axis, declared extent 4.
W_INNER = LaneBinding(
    axis="warpid",
    kind=LANE_W_INNER,
    declared_extent=4,
)
# Flat view of the laneid axis with a declared extent of 32.
LANE_BIND = LaneBinding(
    axis="laneid",
    kind=LANE_FLAT,
    declared_extent=32,
)
# Flat view of the cta_id axis with a declared extent of 1.
CTA_BIND = LaneBinding(
    axis="cta_id",
    kind=LANE_FLAT,
    declared_extent=1,
)
# "thread" binding of kind LANE_CTA_THREAD, declared extent 256.
CTA_THREAD_BIND = LaneBinding(
    axis="thread",
    kind=LANE_CTA_THREAD,
    declared_extent=256,
)
# "thread" binding of kind LANE_WG_THREAD, declared extent 128.
WG_THREAD_BIND = LaneBinding(
    axis="thread",
    kind=LANE_WG_THREAD,
    declared_extent=128,
)
# ---------------------------------------------------------------------------
# §3 scope_switch: split table
# ---------------------------------------------------------------------------
def test_initial_A_single_cta():
    """initial_A(warp_ext=16) without cta_ext gives full zero-offset axes and size 512."""
    active = initial_A(warp_ext=16)

    # Each axis is compared against AxisRange(extent, offset).
    lane_range = AxisRange(32, 0)
    warp_range = AxisRange(16, 0)
    cta_range = AxisRange(1, 0)

    assert active.laneid == lane_range
    assert active.warpid == warp_range
    assert active.cta_id == cta_range
    # 32 * 16 * 1
    assert active.size == 512
def test_initial_A_cluster():
    """initial_A with cta_ext=4 exposes a 4-wide cta_id axis and size 2048."""
    cluster_set = initial_A(warp_ext=16, cta_ext=4)

    expected_cta = AxisRange(4, 0)
    assert cluster_set.cta_id == expected_cta
    # 32 * 16 * 4
    assert cluster_set.size == 2048
def test_axis_modulo_filter_uses_stride():
    """A modulo filter on cta_id shows up in the range's third field and survives narrowing."""
    cluster_set = initial_A(warp_ext=16, cta_ext=4)

    # Presumably (modulus=2, residue=0) -- confirm against filter_modulo's signature.
    strided = filter_modulo(cluster_set, "cta_id", 2, 0)
    assert strided.cta_id == AxisRange(2, 0, 2)

    # Narrowing the strided set with CTA_BIND over [1, 4) leaves a single CTA at offset 2.
    narrowed = filter_narrow(strided, CTA_BIND, 1, 4)
    assert narrowed.cta_id == AxisRange(1, 2, 2)
def test_axis_modulo_filter_two_cta_pair_residues():
    """With two CTAs, each last argument (0, 1) to filter_modulo selects one CTA at that offset."""
    pair = initial_A(warp_ext=16, cta_ext=2)

    for residue in (0, 1):
        selected = filter_modulo(pair, "cta_id", 2, residue)
        assert selected.cta_id == AxisRange(1, residue, 2)
@pytest.mark.parametrize(
    "kappa,expected_inter_axes,expected_intra_axes",
    [
        (THREAD, {"laneid", "warpid", "cta_id"}, set()),
        (WARP, {"warpid", "cta_id"}, {"laneid"}),
        (CTA, {"cta_id"}, {"laneid", "warpid"}),
        (CLUSTER, set(), {"laneid", "warpid", "cta_id"}),
    ],
)
def test_scope_switch_trivial(kappa, expected_inter_axes, expected_intra_axes):
    """scope_switch splits the axis names into inter/intra sets as tabulated per scope kind."""
    full_set = initial_A(warp_ext=16, cta_ext=4)
    partition = scope_switch(full_set, kappa)

    # Only the axis names are compared here, not their ranges.
    inter_axes = set(partition.inter)
    intra_axes = set(partition.intra)

    assert inter_axes == expected_inter_axes
    assert intra_axes == expected_intra_axes
def test_scope_switch_warpgroup_aligned():
    """Switching a 16-warp set to WARPGROUP yields wgid (4, 0) and wid_in_wg (4, 0)."""
    partition = scope_switch(initial_A(warp_ext=16), WARPGROUP)
    inter, intra = partition.inter, partition.intra

    assert inter["wgid"] == AxisRange(4, 0)
    assert inter["cta_id"] == AxisRange(1, 0)
    assert intra["laneid"] == AxisRange(32, 0)
    assert intra["wid_in_wg"] == AxisRange(4, 0)
# ---------------------------------------------------------------------------
# §4.2 warpgroup factoring: 3 cases
# ---------------------------------------------------------------------------
def test_factor_case1_aligned():
    """Case 1: warpid ext=8, off=0 is aligned and factors into two whole warpgroups."""
    aligned = initial_A(warp_ext=8)

    partition = scope_switch(aligned, WARPGROUP)

    assert partition.inter["wgid"] == AxisRange(2, 0)
    assert partition.intra["wid_in_wg"] == AxisRange(4, 0)
def test_factor_case2_fits_in_one_wg():
    """Case 2: warpid ext=2, off=0 fits in one warpgroup, giving a partial wid_in_wg."""
    two_warps = filter_narrow(initial_A(warp_ext=16), WARP_FLAT, 0, 2)

    partition = scope_switch(two_warps, WARPGROUP)

    assert partition.inter["wgid"] == AxisRange(1, 0)
    assert partition.intra["wid_in_wg"] == AxisRange(2, 0)
def test_factor_case2_offset():
    """Case 2 with an offset: warpid ext=2, off=6 -> wid_off=2, which fits (2 <= 4-2)."""
    warps_6_7 = filter_narrow(initial_A(warp_ext=16), WARP_FLAT, 6, 8)

    partition = scope_switch(warps_6_7, WARPGROUP)

    # The expected ranges place both warps in warpgroup 1, at inner offset 2.
    assert partition.inter["wgid"] == AxisRange(1, 1)
    assert partition.intra["wid_in_wg"] == AxisRange(2, 2)
def test_factor_case3_fails():
    """Case 3 (RFC §8.6): warpid[2:6] crosses a warpgroup boundary unaligned and is rejected."""
    straddling = filter_narrow(initial_A(warp_ext=16), WARP_FLAT, 2, 6)

    # The narrowing itself succeeds; only the WARPGROUP switch must raise.
    assert straddling.warpid == AxisRange(4, 2)

    with pytest.raises(ExecContextError, match="crosses warpgroup boundary"):
        scope_switch(straddling, WARPGROUP)
# ---------------------------------------------------------------------------
# §8.1 -- Pure narrowing CTA -> WG -> W
# ---------------------------------------------------------------------------
def test_ex_8_1_cta_wg_warp():
    """RFC §8.1: pure narrowing CTA -> WARPGROUP -> WARP over an unfiltered 16-warp kernel."""
    all_lanes = AxisRange(32, 0)
    one_cta = AxisRange(1, 0)

    entry = ExecContext.at_kernel_entry(warp_ext=16)

    # with T.cta()
    at_cta = entry.with_scope_switch(CTA)
    assert at_cta.inter == {"cta_id": one_cta}
    assert at_cta.intra == {"laneid": all_lanes, "warpid": AxisRange(16, 0)}

    # with T.warpgroup()
    at_wg = at_cta.with_scope_switch(WARPGROUP)
    assert at_wg.inter == {"wgid": AxisRange(4, 0), "cta_id": one_cta}
    assert at_wg.intra == {"laneid": all_lanes, "wid_in_wg": AxisRange(4, 0)}

    # with T.warp()
    at_warp = at_wg.with_scope_switch(WARP)
    assert at_warp.inter == {"warpid": AxisRange(16, 0), "cta_id": one_cta}
    assert at_warp.intra == {"laneid": all_lanes}
# ---------------------------------------------------------------------------
# §8.2 -- Filter + scope_switch
# ---------------------------------------------------------------------------
def test_ex_8_2_filter_then_warpgroup():
    """RFC §8.2: filter warps [0, 8) at CTA scope, then enter WARPGROUP (factor case 1)."""
    all_lanes = AxisRange(32, 0)
    at_cta = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)

    filtered = at_cta.with_filter(WARP_FLAT, 0, 8)
    assert filtered.A.warpid == AxisRange(8, 0)
    # The split is recomputed at cta scope: intra=(lane:32, warp:8).
    assert filtered.intra == {"laneid": all_lanes, "warpid": AxisRange(8, 0)}

    # Entering warpgroup: factor(8, 0) -> case 1.
    at_wg = filtered.with_scope_switch(WARPGROUP)
    assert at_wg.inter == {"wgid": AxisRange(2, 0), "cta_id": AxisRange(1, 0)}
    assert at_wg.intra == {"laneid": all_lanes, "wid_in_wg": AxisRange(4, 0)}
# ---------------------------------------------------------------------------
# §8.3 -- Sugar form T.warp(warpid[2:4])
# ---------------------------------------------------------------------------
def test_ex_8_3_sugar_warp_range():
    """RFC §8.3: the sugar T.warp(warpid[2:4]) desugars to a filter followed by a WARP switch."""
    at_cta = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)

    # Desugared form: filter warpid[2:4] first, then switch to warp scope.
    filtered = at_cta.with_filter(WARP_FLAT, 2, 4)
    at_warp = filtered.with_scope_switch(WARP)

    selected_warps = AxisRange(2, 2)
    assert at_warp.A.warpid == selected_warps
    assert at_warp.inter == {"warpid": selected_warps, "cta_id": AxisRange(1, 0)}
    assert at_warp.intra == {"laneid": AxisRange(32, 0)}
# ---------------------------------------------------------------------------
# §8.4 -- Widen after filter (warp -> warpgroup)
# ---------------------------------------------------------------------------
def test_ex_8_4_widen_warp_to_wg():
    """RFC §8.4: after filtering warps [0, 4) and entering WARP, widening to WARPGROUP works."""
    at_cta = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)
    at_warp = at_cta.with_filter(WARP_FLAT, 0, 4).with_scope_switch(WARP)

    # Widen from warp scope back out to warpgroup scope.
    widened = at_warp.with_scope_switch(WARPGROUP)

    expected_inter = {"wgid": AxisRange(1, 0), "cta_id": AxisRange(1, 0)}
    expected_intra = {"laneid": AxisRange(32, 0), "wid_in_wg": AxisRange(4, 0)}
    assert widened.inter == expected_inter
    assert widened.intra == expected_intra
# ---------------------------------------------------------------------------
# §8.5 -- Partial warp selection -> warpgroup (partial intra)
# ---------------------------------------------------------------------------
def test_ex_8_5_partial_wg():
    """RFC §8.5: selecting warps [0, 2) then entering WARPGROUP gives a partial intra range."""
    at_cta = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)
    two_warps = at_cta.with_filter(WARP_FLAT, 0, 2)
    at_wg = two_warps.with_scope_switch(WARPGROUP)

    # Factor case 2: 2 <= 4-0.
    expected_inter = {"wgid": AxisRange(1, 0), "cta_id": AxisRange(1, 0)}
    expected_intra = {"laneid": AxisRange(32, 0), "wid_in_wg": AxisRange(2, 0)}
    assert at_wg.inter == expected_inter
    assert at_wg.intra == expected_intra
# ---------------------------------------------------------------------------
# §8.6 -- Cross warpgroup boundary (factor fails)
# ---------------------------------------------------------------------------
def test_ex_8_6_factor_fail():
    """RFC §8.6: warps [2, 6) are accepted at CTA scope but cannot be factored into warpgroups."""
    at_cta = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)

    # with_filter recomputes (inter, intra) for the current scope_kind=cta -- still OK.
    straddling = at_cta.with_filter(WARP_FLAT, 2, 6)
    assert straddling.A.warpid == AxisRange(4, 2)

    # The switch to warpgroup scope is the step that must fail.
    with pytest.raises(ExecContextError, match="crosses warpgroup boundary"):
        straddling.with_scope_switch(WARPGROUP)
# ---------------------------------------------------------------------------
# §8.7 -- Deep mixed nesting
# ---------------------------------------------------------------------------
def test_ex_8_7_deep_nested():
    """RFC §8.7: interleaved filters and scope switches from CTA down to a lane subset."""
    one_cta = AxisRange(1, 0)
    all_lanes = AxisRange(32, 0)

    at_cta = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)

    # Warps [0, 8) viewed as two warpgroups.
    at_wg = at_cta.with_filter(WARP_FLAT, 0, 8).with_scope_switch(WARPGROUP)
    assert at_wg.inter == {"wgid": AxisRange(2, 0), "cta_id": one_cta}

    # Filtering while at warpgroup scope recomputes the split: factor(2, 0) -> case 2.
    wg_two_warps = at_wg.with_filter(WARP_FLAT, 0, 2)
    assert wg_two_warps.inter == {"wgid": AxisRange(1, 0), "cta_id": one_cta}
    assert wg_two_warps.intra == {"laneid": all_lanes, "wid_in_wg": AxisRange(2, 0)}

    at_warp = wg_two_warps.with_scope_switch(WARP)
    expected_warp_inter = {"warpid": AxisRange(2, 0), "cta_id": one_cta}
    assert at_warp.inter == expected_warp_inter
    assert at_warp.intra == {"laneid": all_lanes}

    # A lane filter shrinks intra only; inter is expected to stay as it was.
    lane_subset = at_warp.with_filter(LANE_BIND, 0, 8)
    assert lane_subset.intra == {"laneid": AxisRange(8, 0)}
    assert lane_subset.inter == expected_warp_inter
# ---------------------------------------------------------------------------
# §8.8 -- FA4 pattern: 3 sibling filter branches
# ---------------------------------------------------------------------------
def test_ex_8_8_fa4_pattern():
    """RFC §8.8 (FA4 pattern): three sibling filter branches taken from one CTA-scope root."""
    all_lanes = AxisRange(32, 0)
    one_cta = AxisRange(1, 0)
    full_wg = AxisRange(4, 0)

    root = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)

    # Branch 1: warp 12 (single warp, tcgen05 MMA elected).
    mma_warp = root.with_filter(WARP_FLAT, 12, 13)
    assert mma_warp.A.warpid == AxisRange(1, 12)
    assert mma_warp.intra == {"laneid": all_lanes, "warpid": AxisRange(1, 12)}

    # Branch 2: softmax warpgroups (warps 0-7).
    softmax_wgs = root.with_filter(WARP_FLAT, 0, 8).with_scope_switch(WARPGROUP)
    assert softmax_wgs.inter == {"wgid": AxisRange(2, 0), "cta_id": one_cta}
    assert softmax_wgs.intra == {"laneid": all_lanes, "wid_in_wg": full_wg}

    # Branch 3: correction warpgroup (warps 8-11 = wg2).
    correction = root.with_filter(WARP_FLAT, 8, 12)
    assert correction.A.warpid == AxisRange(4, 8)
    assert correction.intra == {"laneid": all_lanes, "warpid": AxisRange(4, 8)}

    # Branch 3 should also factor cleanly when it enters warpgroup scope.
    correction_wg = correction.with_scope_switch(WARPGROUP)
    assert correction_wg.inter == {"wgid": AxisRange(1, 2), "cta_id": one_cta}
    assert correction_wg.intra == {"laneid": all_lanes, "wid_in_wg": full_wg}
# ---------------------------------------------------------------------------
# §8.9 -- Cross-CTA with widening to cluster
# ---------------------------------------------------------------------------
def test_ex_8_9_cross_cta_cluster():
    """RFC §8.9: a 4-CTA kernel, narrowed to a warp and separately widened to CLUSTER."""
    all_lanes = AxisRange(32, 0)
    four_ctas = AxisRange(4, 0)

    at_cta = ExecContext.at_kernel_entry(warp_ext=16, cta_ext=4).with_scope_switch(CTA)
    assert at_cta.inter == {"cta_id": four_ctas}

    # Filter down to warp 0, then switch to warp scope.
    warp0 = at_cta.with_filter(WARP_FLAT, 0, 1).with_scope_switch(WARP)
    assert warp0.inter == {"warpid": AxisRange(1, 0), "cta_id": four_ctas}
    assert warp0.intra == {"laneid": all_lanes}

    # Back at cta scope (the unfiltered context), enter warpgroup.
    at_wg = at_cta.with_scope_switch(WARPGROUP)
    assert at_wg.inter == {"wgid": AxisRange(4, 0), "cta_id": four_ctas}

    # Widen to cluster: every axis is expected to move into intra.
    at_cluster = at_wg.with_scope_switch(CLUSTER)
    assert at_cluster.inter == {}
    expected_intra = {
        "laneid": all_lanes,
        "warpid": AxisRange(16, 0),
        "cta_id": four_ctas,
    }
    assert at_cluster.intra == expected_intra
# ---------------------------------------------------------------------------
# §8.10 -- same scenario as §8.3 apart from the prose; covered by test_ex_8_3_sugar_warp_range
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Rule 1 & 5: filter can only shrink A; saved/restored across scope exit
# (Restoration is the caller's (IR walker's) responsibility -- ExecContext
# is immutable, each with_filter returns a fresh ctx. Test that the parent
# is untouched.)
# ---------------------------------------------------------------------------
def test_filter_is_pure():
    """with_filter returns a narrowed context and leaves the context it was called on unchanged."""
    parent = ExecContext.at_kernel_entry(warp_ext=16).with_scope_switch(CTA)

    narrowed = parent.with_filter(WARP_FLAT, 0, 8)

    # The parent must still see all 16 warps after the child was derived from it.
    assert parent.A.warpid == AxisRange(16, 0)
    assert narrowed.A.warpid == AxisRange(8, 0)
def test_filter_empty_range_rejected():
    """filter_narrow with equal bounds (5, 5) raises ExecContextError."""
    full_set = initial_A(warp_ext=16)

    with pytest.raises(ExecContextError, match="empty or inverted"):
        filter_narrow(full_set, WARP_FLAT, 5, 5)
def test_filter_out_of_range_rejected():
    """Narrowing to [8, 12) after already narrowing to [0, 4) raises ExecContextError."""
    first_four = filter_narrow(initial_A(warp_ext=16), WARP_FLAT, 0, 4)

    # [8, 12) is disjoint from the [0, 4) that remains.
    with pytest.raises(ExecContextError, match="empty range"):
        filter_narrow(first_four, WARP_FLAT, 8, 12)
def test_filter_flat_cta_thread_full_warp_range():
    """CTA thread range [0, 128) maps to four whole warps with every lane kept."""
    eight_warps = initial_A(warp_ext=8)

    narrowed = filter_narrow(eight_warps, CTA_THREAD_BIND, 0, 128)

    assert narrowed.warpid == AxisRange(4, 0)
    assert narrowed.laneid == AxisRange(32, 0)
def test_filter_flat_cta_thread_single_warp_lane_range():
    """CTA thread range [34, 40) maps to warp 1 with lanes [2, 8)."""
    eight_warps = initial_A(warp_ext=8)

    narrowed = filter_narrow(eight_warps, CTA_THREAD_BIND, 34, 40)

    # Expected: one warp at offset 1, six lanes starting at lane 2.
    assert narrowed.warpid == AxisRange(1, 1)
    assert narrowed.laneid == AxisRange(6, 2)
def test_filter_flat_cta_thread_nonrectangular_rejected():
    """CTA thread range [20, 50) is rejected with a "non-rectangular" ExecContextError."""
    eight_warps = initial_A(warp_ext=8)

    with pytest.raises(ExecContextError, match="non-rectangular"):
        filter_narrow(eight_warps, CTA_THREAD_BIND, 20, 50)
def test_filter_flat_warpgroup_thread_range_inside_one_warpgroup():
    """Within warpgroup 1, warpgroup-thread range [32, 64) selects warp 5 with all lanes."""
    eight_warps = initial_A(warp_ext=8)

    # First pin the set to a single warpgroup, then narrow by warpgroup-relative thread.
    one_wg = filter_narrow(eight_warps, WG_OUTER, 1, 2)
    narrowed = filter_narrow(one_wg, WG_THREAD_BIND, 32, 64)

    assert narrowed.warpid == AxisRange(1, 5)
    assert narrowed.laneid == AxisRange(32, 0)
def test_filter_flat_warpgroup_thread_full_range_across_warpgroups_is_noop():
    """The full warpgroup-thread range [0, 128) over eight warps leaves warpid and laneid as-is."""
    eight_warps = initial_A(warp_ext=8)

    unchanged = filter_narrow(eight_warps, WG_THREAD_BIND, 0, 128)

    assert unchanged.warpid == AxisRange(8, 0)
    assert unchanged.laneid == AxisRange(32, 0)
def test_filter_flat_warpgroup_thread_partial_range_across_warpgroups_rejected():
    """A partial warpgroup-thread range [0, 64) over eight warps raises ExecContextError."""
    eight_warps = initial_A(warp_ext=8)

    with pytest.raises(ExecContextError, match="multiple warpgroups"):
        filter_narrow(eight_warps, WG_THREAD_BIND, 0, 64)
# ---------------------------------------------------------------------------
# Factor-lane bindings: wg_outer and w_inner
# ---------------------------------------------------------------------------
def test_filter_wg_outer():
    """Filtering WG_OUTER to [1, 3) keeps warpgroups 1..2, i.e. warps 4..11."""
    full_set = initial_A(warp_ext=16)

    narrowed = filter_narrow(full_set, WG_OUTER, 1, 3)

    assert narrowed.warpid == AxisRange(8, 4)
def test_filter_wg_outer_unaligned_rejected():
    """A WG_OUTER filter on a set whose warp offset is 2 (not WG-aligned) is rejected."""
    # Warps [2, 6): the set starts at warp offset 2.
    unaligned = filter_narrow(initial_A(warp_ext=16), WARP_FLAT, 2, 6)

    with pytest.raises(ExecContextError, match="aligned to WG_SIZE"):
        filter_narrow(unaligned, WG_OUTER, 0, 1)
def test_filter_w_inner():
    """A W_INNER filter of [1, 3) inside warpgroup 1 (warps 4..7) selects warps 5..6."""
    full_set = initial_A(warp_ext=16)

    # Narrow into a single warpgroup first; only then is the inner filter valid.
    wg1 = filter_narrow(full_set, WARP_FLAT, 4, 8)
    inner = filter_narrow(wg1, W_INNER, 1, 3)

    assert inner.warpid == AxisRange(2, 5)
def test_filter_w_inner_spanning_wg_rejected():
    """A W_INNER filter on a set that still spans all four warpgroups is rejected."""
    spans_all_wgs = initial_A(warp_ext=16)

    with pytest.raises(ExecContextError, match="spans multiple warpgroups"):
        filter_narrow(spans_all_wgs, W_INNER, 0, 2)
# Allow running this file directly: invoke pytest on it in verbose mode and
# use pytest's return value as the process exit status.
if __name__ == "__main__":
    import sys

    exit_status = pytest.main([__file__, "-v"])
    sys.exit(exit_status)