blob: 0aadb260fa0f7036141b5f186d80223bfe031c9a [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for tvm.tirx.lang.alloc_pool validation."""
import pytest
from tvm.tirx.lang.alloc_pool import _validate_mma_alloc_shape
from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode
# ---------------------------------------------------------------------------
# alloc_mma shape validation: bad inputs raise actionable ValueError instead of
# the opaque "Divide by zero" diagnostic that ``Layout.tile_to`` would emit.
# ---------------------------------------------------------------------------
class TestAllocMmaValidationRowBytes:
"""row width (cols * itemsize) must be a positive multiple of swizzle atom bytes."""
def test_bf16_32cols_128b_swizzle_too_narrow(self):
# The exact case that bit gdn-prefill v1_0 / v1_2 (eval R10).
# Row = 32 * 2B = 64B < 128B atom.
with pytest.raises(ValueError, match=r"64B rows.*128B swizzle atom"):
_validate_mma_alloc_shape((128, 32), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
def test_error_suggests_smaller_swizzle(self):
try:
_validate_mma_alloc_shape((128, 32), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
except ValueError as e:
assert "SWIZZLE_64B_ATOM" in str(e), f"missing fix-it hint: {e}"
else:
pytest.fail("should have raised")
def test_error_suggests_widening_cols(self):
try:
_validate_mma_alloc_shape((128, 32), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
except ValueError as e:
assert "multiple of 64 elements" in str(e), f"missing widen hint: {e}"
else:
pytest.fail("should have raised")
def test_fp32_16cols_128b_swizzle_too_narrow(self):
# Row = 16 * 4B = 64B < 128B atom.
with pytest.raises(ValueError, match=r"64B rows.*128B swizzle atom"):
_validate_mma_alloc_shape((128, 16), "float32", SwizzleMode.SWIZZLE_128B_ATOM)
def test_3d_shape_validates_last_dim(self):
# Validation must consider shape[-1], not shape[0].
with pytest.raises(ValueError, match=r"64B rows"):
_validate_mma_alloc_shape((2, 128, 32), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
class TestAllocMmaValidationRowCount:
"""rows (shape[-2]) must be a positive multiple of the 8-row atom."""
def test_rows_below_atom_rejected(self):
with pytest.raises(ValueError, match=r"shape\[-2\]=4.*multiple of 8"):
_validate_mma_alloc_shape((4, 64), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
def test_rows_not_multiple_of_8_rejected(self):
with pytest.raises(ValueError, match=r"shape\[-2\]=12.*multiple of 8"):
_validate_mma_alloc_shape((12, 64), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
class TestAllocMmaValidationRank:
"""rank-1 shapes cannot be tiled with a 2-D swizzle atom."""
def test_rank_one_rejected(self):
with pytest.raises(ValueError, match=r"fewer than 2 dimensions"):
_validate_mma_alloc_shape((128,), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM)
class TestAllocMmaValidationValid:
"""combinations that should succeed must not be rejected."""
@pytest.mark.parametrize(
"shape,dtype,mode",
[
# The fix path the agent should pick when row_bytes >= 128.
((128, 64), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM),
((128, 128), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM),
# Or downgrade to a swizzle whose atom matches the row.
((128, 32), "bfloat16", SwizzleMode.SWIZZLE_64B_ATOM),
((128, 16), "bfloat16", SwizzleMode.SWIZZLE_32B_ATOM),
# 3-D request validates the last two dims only.
((2, 128, 64), "bfloat16", SwizzleMode.SWIZZLE_128B_ATOM),
# fp32 with row width >= atom.
((128, 32), "float32", SwizzleMode.SWIZZLE_128B_ATOM),
# fp8 (1B) with row width >= atom.
((128, 128), "float8_e4m3", SwizzleMode.SWIZZLE_128B_ATOM),
],
)
def test_valid_combinations_accepted(self, shape, dtype, mode):
_validate_mma_alloc_shape(shape, dtype, mode)
def test_swizzle_none_skips_validation(self):
# SWIZZLE_NONE has no atom — even otherwise-bad shapes are allowed.
_validate_mma_alloc_shape((128, 32), "bfloat16", SwizzleMode.SWIZZLE_NONE)
_validate_mma_alloc_shape((3, 5), "bfloat16", SwizzleMode.SWIZZLE_NONE)
_validate_mma_alloc_shape((128,), "bfloat16", SwizzleMode.SWIZZLE_NONE)
if __name__ == "__main__":
pytest.main([__file__, "-v"])