blob: 5ed20e7162fe59473ea9d6ed51539faabbbce7c3 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
from tvm.tirx.analysis import verify_tirx_well_formed as verify
def test_root_scope():
    """Check that minimal prim_funcs pass the well-formed verifier.

    ``test1`` and ``test4`` contain only ``T.device_entry()``; ``test2`` and
    ``test3`` have an empty body. None of them is expected to raise.
    """
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test1() -> None:
        T.device_entry()
        pass

    @T.prim_func(check_well_formed=False)
    def test2() -> None:
        pass

    @T.prim_func(check_well_formed=False)
    def test3() -> None:
        pass

    @T.prim_func(check_well_formed=False)
    def test4() -> None:
        T.device_entry()
        pass
    # fmt: on

    # Verification happens in definition order; the first failure aborts the test.
    for func in (test1, test2, test3, test4):
        verify(func)
def test_nested_scope():
    """Check that prim_funcs made of a device entry plus no-op statements verify.

    Every fixture starts with ``T.device_entry()`` and is followed by one or
    two ``pass`` statements. None of them is expected to raise.
    """
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test1() -> None:
        T.device_entry()
        pass
        pass

    @T.prim_func(check_well_formed=False)
    def test2() -> None:
        T.device_entry()
        pass

    @T.prim_func(check_well_formed=False)
    def test3() -> None:
        T.device_entry()
        pass

    @T.prim_func(check_well_formed=False)
    def test4() -> None:
        T.device_entry()
        pass
        pass
    # fmt: on

    # Verification happens in definition order; the first failure aborts the test.
    for func in (test1, test2, test3, test4):
        verify(func)
def test_scope_id_consistency():
    """Check that scope-id extents declared side by side must agree.

    The fixtures come in accepted/rejected pairs:

    * ``test1``/``test2`` verify cleanly; ``test3`` is ``test2`` with
      ``T.warp_id([2])`` instead of ``T.warp_id([4])`` and must be rejected.
    * ``test4`` verifies cleanly; ``test5`` is ``test4`` with
      ``T.cluster_id([3, 5, 12])`` instead of ``[4, 5, 12]`` and must be
      rejected.
    * ``test6``/``test7`` repeat ``test4``/``test5`` with ``T.cluster_id``
      declared first.
    """
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test1():
        T.device_entry()
        T.cta_id([32])
        T.warp_id([4])
        T.lane_id([32])
        pass

    @T.prim_func(check_well_formed=False)
    def test2():
        T.device_entry()
        T.cta_id([32])
        T.warp_id([4])
        T.lane_id([32])
        T.thread_id([128])
        pass

    @T.prim_func(check_well_formed=False)
    def test3():
        T.device_entry()
        T.cta_id([32])
        T.warp_id([2])
        T.lane_id([32])
        T.thread_id([128])
        pass

    @T.prim_func(check_well_formed=False)
    def test4():
        T.device_entry()
        bx, by, bz = T.cta_id([8, 10, 12])
        cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1])
        clx, cly, clz = T.cluster_id([4, 5, 12])
        T.evaluate(bx + by + bz)
        T.evaluate(cbx + cby + cbz)
        T.evaluate(clx + cly + clz)

    @T.prim_func(check_well_formed=False)
    def test5():
        T.device_entry()
        bx, by, bz = T.cta_id([8, 10, 12])
        cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1])
        clx, cly, clz = T.cluster_id([3, 5, 12])
        T.evaluate(bx + by + bz)
        T.evaluate(cbx + cby + cbz)
        T.evaluate(clx + cly + clz)

    @T.prim_func(check_well_formed=False)
    def test6():
        T.device_entry()
        clx, cly, clz = T.cluster_id([4, 5, 12])
        bx, by, bz = T.cta_id([8, 10, 12])
        cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1])
        T.evaluate(bx + by + bz)
        T.evaluate(cbx + cby + cbz)
        T.evaluate(clx + cly + clz)

    @T.prim_func(check_well_formed=False)
    def test7():
        T.device_entry()
        clx, cly, clz = T.cluster_id([3, 5, 12])
        bx, by, bz = T.cta_id([8, 10, 12])
        cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1])
        T.evaluate(bx + by + bz)
        T.evaluate(cbx + cby + cbz)
        T.evaluate(clx + cly + clz)
    # fmt: on

    # (fixture, regex the verifier error must match) -- None means "must pass".
    # Kept in definition order so the first failure aborts the test.
    cluster_mismatch = "Inconsistent extents|non-divisible extents"
    cases = (
        (test1, None),
        (test2, None),
        (test3, "Inconsistent extents for scope"),
        (test4, None),
        (test5, cluster_mismatch),
        (test6, None),
        (test7, cluster_mismatch),
    )
    for func, expected_error in cases:
        if expected_error is None:
            verify(func)
            continue
        with pytest.raises(Exception, match=expected_error):
            verify(func)
def test_layout():
    """Check that buffers allocated with an explicit layout pass verification.

    Two fixtures are defined and verified one after the other: ``test1``
    allocates a buffer with ``T.TileLayout`` and ``test2`` allocates a
    ``"shared"``-scope buffer with ``T.SwizzleLayout``. Each writes element 0
    of its buffer and is expected to verify without raising.
    """
    ### TileLayout
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test1():
        T.device_entry()
        T.cta_id([32])
        T.warp_id([4])
        T.lane_id([32])
        A = T.alloc_buffer((2,), layout=T.TileLayout(T.S[2, 1]))
        A[0] = 0
    # fmt: on
    verify(test1)

    ### SwizzleLayout
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test2():
        T.device_entry()
        T.cta_id([32])
        T.warp_id([4])
        T.lane_id([32])
        A = T.alloc_buffer((512,), scope="shared", layout=T.SwizzleLayout(3, 3, 3))
        A[0] = 0
    # fmt: on
    verify(test2)
def test_host():
    """Check that a prim_func with statements before ``T.device_entry()`` verifies.

    The fixture matches a 16x16 float32 buffer, allocates a ``"tensormap"``
    handle with ``T.tvm_stack_alloca`` and fills it through the packed call
    ``runtime.cuTensorMapEncodeTiled`` before ``T.device_entry()``. After the
    device entry it binds ``blockIdx.x`` / ``threadIdx.x`` loops whose body
    uses mbarrier and bulk-tensor copy PTX intrinsics on the tensormap.
    """
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test1(A_ptr: T.handle):
        # Statements ahead of T.device_entry(): buffer binding and tensormap setup.
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
        A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1)
        T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", 2, A.data, 16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0) # noqa: E501
        T.device_entry()
        for blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                bar = T.alloc_buffer((1,), "uint64", scope="shared", align=8)
                phase = T.alloc_buffer((1,), "int32", scope="local")
                A_smem = T.alloc_buffer((16, 16), "float32", scope="shared", align=128)
                phase[0] = 0
                # NOTE(review): the guard is taken to cover only the four
                # barrier-init / copy-issue calls below, with the wait executed
                # by every thread -- confirm against the intended nesting.
                if threadIdx == 0:
                    T.ptx.mbarrier.init(bar.data, 1)
                    T.ptx.fence.proxy_async("shared::cta")
                    T.ptx.cp_async.bulk.tensor.g2c(2, A_smem.data, bar.data, T.address_of(A_map), 0, 1, "", 0, 0) # noqa: E501
                    # 16*16*4 equals the byte size of the 16x16 float32 A_smem tile.
                    T.ptx.mbarrier.arrive.expect_tx(bar.data, 16*16*4)
                T.ptx.mbarrier.try_wait(bar.data, phase[0])
                phase[0] = phase[0] ^ 1
                T.print_buffer(A_smem.data, "float32", False, False, 2, 16*16)
    # fmt: on
    verify(test1)
def test_device_func():
    """Check that device functions verify when ``device_func=True`` is passed.

    ``test1`` fills its buffer through ``Tx.cta.fill`` and ``test2`` through
    plain ``Tx.fill``; both are expected to verify without raising.
    """
    # Per-call exec-scope migration: scope is now attached per op via the
    # ``T.op[scope](...)`` subscription surface instead of a ``with T.cta():``
    # region. ``test1`` exercises a per-call-scoped op; ``test2`` the plain
    # (unscoped) op. The old multi-root-scope negative case asserted the removed
    # "only one root scope" verifier rule and no longer has an equivalent, so it
    # is dropped.
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def test1(A: T.Buffer((128,), "float32")):
        T.device_entry()
        T.cta_id([1])
        T.thread_id([128])
        Tx.cta.fill(A, 0.)

    @T.prim_func(check_well_formed=False)
    def test2(A: T.Buffer((128,), "float32")):
        T.device_entry()
        T.cta_id([128])
        T.thread_id([128])
        Tx.fill(A, 0.)
    # fmt: on

    # Verification happens in definition order; the first failure aborts the test.
    for func in (test1, test2):
        verify(func, device_func=True)
def test_preferred_cluster_validation():
    """Check validation of the ``preferred`` argument of scope-id definitions.

    * ``test1`` passes ``preferred`` with the same length as the extents to
      ``T.cta_id_in_cluster`` and must verify.
    * ``test2`` passes a one-element ``preferred`` against two extents and
      must be rejected by ``verify``.
    * ``test3`` passes ``preferred`` to ``T.cta_id``; the error is expected
      while the prim_func is being defined, so the definition itself sits
      inside ``pytest.raises`` and ``verify`` is never called on it.
    """
    # fmt: off
    # Valid: cluster→cta with preferred_extents matching size
    @T.prim_func(check_well_formed=False)
    def test1() -> None:
        T.device_entry()
        cbx, cby = T.cta_id_in_cluster([2, 1], preferred=[2, 2])
        tx = T.thread_id([128])
        T.evaluate(cbx + cby + tx)

    # Invalid: preferred size doesn't match extents size (caught at verify time)
    @T.prim_func(check_well_formed=False)
    def test2() -> None:
        T.device_entry()
        cbx, cby = T.cta_id_in_cluster([2, 1], preferred=[2])
        tx = T.thread_id([128])
        T.evaluate(cbx + cby + tx)
    # fmt: on
    verify(test1)
    with pytest.raises(Exception, match="preferred_extents must have the same size"):
        verify(test2)

    # Invalid: preferred on a non-cluster→cta scope (caught at IR build time)
    # No ``match=`` is given, so any exception raised by the definition passes.
    with pytest.raises(Exception):
        # fmt: off
        @T.prim_func(check_well_formed=False)
        def test3() -> None:
            T.device_entry()
            bx = T.cta_id([128], preferred=[256])
            tx = T.thread_id([128])
            T.evaluate(bx + tx)
        # fmt: on
def test_scope_id_deferred_relaxed_at_construction():
    """Deferred scope ids (declared without extents) must be accepted by the
    well-formed check even if no sibling definition supplies enough information
    to resolve them -- strict resolution is deferred to LowerTIRx.

    Three shapes are covered: one deferred id next to an explicit one, every
    id deferred, and a deferred id alongside explicit sibling extents.
    """
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def partial_only_cta():
        T.device_entry()
        bx = T.cta_id() # deferred kernel→cta, no closure source
        tx = T.thread_id([128]) # explicit
        T.evaluate(bx + tx)

    @T.prim_func(check_well_formed=False)
    def all_deferred():
        T.device_entry()
        bx = T.cta_id()
        wg = T.warpgroup_id()
        warp = T.warp_id_in_wg()
        lane = T.lane_id()
        T.evaluate(bx + wg + warp + lane)

    @T.prim_func(check_well_formed=False)
    def mixed():
        T.device_entry()
        # kCtaWarp=4, kWarpThread=32 → kCtaThread=128 derivable.
        T.warp_id([4])
        T.lane_id([32])
        T.thread_id() # deferred kCtaThread, resolvable via closure
        pass
    # fmt: on

    # All three accepted by well-formed: deferred extents are tolerated.
    for func in (partial_only_cta, all_deferred, mixed):
        verify(func)
def test_scope_id_deferred_consistency_still_enforced():
    """Even with deferred defs, known-known consistency between sibling defs
    must still be enforced by the closure check.

    The fixture declares ``T.thread_id()`` without extents next to an explicit
    ``T.thread_id([64])`` that disagrees with the warp and lane extents;
    ``verify`` must raise with a message matching
    ``"Inconsistent extents for scope"``.
    """
    # fmt: off
    @T.prim_func(check_well_formed=False)
    def inconsistent():
        # 4 warps * 32 lanes = 128 threads, but explicit thread_id says 64 -> error.
        T.device_entry()
        T.cta_id([32])
        T.warp_id([4])
        T.lane_id([32])
        T.thread_id() # deferred (shouldn't shadow the conflict)
        T.thread_id([64]) # conflicts with derived kCtaThread=128
        pass
    # fmt: on
    with pytest.raises(Exception, match="Inconsistent extents for scope"):
        verify(inconsistent)
def test_scope_id_deferred_multi_var_rejected():
    """A deferred ScopeIdDef (no extents) must define exactly one Var.

    Multi-var deferred defs have no well-defined recovery from fused closure
    values, so constructing one is expected to raise.
    """
    # The C++ ScopeIdDef ctor enforces this; constructing such a def from the
    # parser path is not currently expressible (parser only emits single-Var
    # deferred), but we exercise the FFI-level guard directly.
    from tvm.tirx.exec_scope import ScopeIdDef
    from tvm.tirx.expr import Var

    def make_deferred_def(var_count):
        # Passes None as the second argument, which this test treats as the
        # "no extents" (deferred) form.
        scope_vars = [Var("", "int32") for _ in range(var_count)]
        return ScopeIdDef(scope_vars, None, "kernel", "cta")

    # Single-Var deferred form is fine.
    make_deferred_def(1)

    # Two-Var deferred should be rejected.
    with pytest.raises(Exception, match="Deferred ScopeIdDef.*must define exactly one Var"):
        make_deferred_def(2)
if __name__ == "__main__":
    # Run every test in this module, in definition order, when the file is
    # executed directly instead of being collected by pytest.
    test_root_scope()
    test_nested_scope()
    test_scope_id_consistency()
    test_layout()
    test_host()
    test_device_func()
    # Previously missing from this runner, so direct execution skipped it.
    test_preferred_cluster_validation()
    test_scope_id_deferred_relaxed_at_construction()
    test_scope_id_deferred_consistency_still_enforced()
    test_scope_id_deferred_multi_var_rejected()