blob: 0c50d4bb222fd5d1ccd368d35fcfd7913fc188aa [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""AssertStmt codegen tests: verify kind and message_parts produce correct exceptions."""
import pytest
import tvm
import tvm.testing
from tvm.script import tirx as T
codegen_target = tvm.testing.parameter("llvm", "c")
def test_assert_runtime_error(codegen_target):
"""AssertStmt with RuntimeError kind produces RuntimeError."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("RuntimeError", ["Expected non-null input"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(RuntimeError, match="Expected non-null input"):
lib(0)
def test_assert_value_error(codegen_target):
"""AssertStmt with ValueError kind produces ValueError."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("ValueError", ["Shape mismatch: expected 4 got 8"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(ValueError, match="Shape mismatch"):
lib(0)
def test_assert_type_error(codegen_target):
"""AssertStmt with TypeError kind produces TypeError."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("TypeError", ["Expected Tensor but got int"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(TypeError, match="Expected Tensor but got int"):
lib(0)
def test_assert_multi_part_message(codegen_target):
"""Multi-part messages are correctly concatenated at runtime."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("ValueError", ["Expected shape ", "4", " but got ", "8"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(ValueError, match="Expected shape 4 but got 8"):
lib(0)
def test_assert_passing_condition(codegen_target):
"""Passing assertion does not raise."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("RuntimeError", ["This should not be raised"])
lib = tvm.compile(func, target=codegen_target)
lib(1) # should pass without error
def test_assert_many_parts(codegen_target):
"""Assertion with 8 parts concatenated correctly."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("RuntimeError", ["p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(RuntimeError, match="p0p1p2p3p4p5p6p7"):
lib(0)
def test_tvmscript_assert_preserves_kind(codegen_target):
"""Regression: TVMScript structured assert preserves kind at runtime."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("ValueError", ["x must be positive"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(ValueError, match="x must be positive"):
lib(0)
def test_tvmscript_assert_preserves_parts(codegen_target):
"""Regression: TVMScript structured assert with separate parts."""
@T.prim_func
def func(x: T.int32):
assert x > 0, ("ValueError", ["x must be ", "positive"])
lib = tvm.compile(func, target=codegen_target)
with pytest.raises(ValueError, match="x must be positive"):
lib(0)
if __name__ == "__main__":
tvm.testing.main()