blob: 621d38059a63c9c9d44d01780b606b2194bdfd2d [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for pass-time error enrichment with TVMScript-rendered locations.
A pass body that throws an error carrying a VisitErrorContext (e.g. a relax op
validator) is caught by the leaf pass executor and re-thrown with the failing
pass name plus the offending location rendered as underlined TVMScript.
"""
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import IRModule
def _bad_matmul_module():
    """Return an IRModule whose ``main`` binds a shape-incompatible matmul.

    The module is assembled directly from relax IR nodes; no TVMScript is
    parsed. ``main`` takes ``x`` ([3, 4], float32) and ``y`` ([5, 6], float32)
    and binds ``matmul(x, y)`` to an un-annotated variable ``lv``, which is
    also the body's result.

    An explicit [3, 6] float32 return struct info is supplied as a placeholder
    so that the function can be constructed. The tests in this file rely on
    ``Normalize`` re-running inference over the binding so that the matmul
    validator rejects the mismatched shapes while the pass is executing.
    """
    lhs = relax.Var("x", relax.TensorStructInfo([3, 4], "float32"))
    rhs = relax.Var("y", relax.TensorStructInfo([5, 6], "float32"))
    # No struct info on the bound var: it is left for inference to fill in.
    result = relax.Var("lv")

    bad_binding = relax.VarBinding(result, relax.op.matmul(lhs, rhs))
    seq_body = relax.SeqExpr([relax.BindingBlock([bad_binding])], result)

    main_func = relax.Function(
        [lhs, rhs],
        seq_body,
        ret_struct_info=relax.TensorStructInfo([3, 6], "float32"),
        is_pure=True,
    )
    main_func = main_func.with_attr("global_symbol", "main")
    return IRModule({relax.GlobalVar("main"): main_func})
@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_pass_error_renders_underlined_tvmscript():
    """Check the enriched error produced when Normalize hits the bad matmul.

    Runs ``Normalize`` over the module from ``_bad_matmul_module`` and expects
    a ``ValueError`` whose text contains the validator's own message, the name
    of the failing pass, and a TVMScript rendering of the offending binding
    with an underline beneath it.
    """
    bad_mod = _bad_matmul_module()

    with pytest.raises(ValueError) as raised:
        normalize_pass = relax.transform.Normalize()
        normalize_pass(bad_mod)

    rendered = str(raised.value)
    expected_fragments = (
        # The original validator message is preserved.
        "Matmul requires the reduction length",
        # The failing pass is named.
        "Error in pass: Normalize",
        # The location is rendered as TVMScript ...
        "Location (TVMScript):",
        # ... showing the offending expression ...
        "R.matmul(x, y",
        # ... with an underline marking it.
        "^^^",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_sequential_does_not_double_append():
    """Check that wrapping the failing pass in a Sequential enriches only once.

    Only the leaf pass is expected to add the location block; the enclosing
    ``Sequential`` must let the error through without appending a second copy.
    """
    bad_mod = _bad_matmul_module()
    pipeline = tvm.transform.Sequential([relax.transform.Normalize()])

    with pytest.raises(ValueError) as raised:
        pipeline(bad_mod)

    rendered = str(raised.value)
    # Exactly one location block: enrichment happened at one level only.
    location_blocks = rendered.count("Location (TVMScript):")
    assert location_blocks == 1
    # The leaf pass, not the Sequential wrapper, is the one named.
    assert "Error in pass: Normalize" in rendered
@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_error_without_resolvable_node_is_not_masked():
    """Check that an error with no locatable node keeps its original message.

    A module pass raises a plain ``InternalError`` that carries nothing to
    resolve against the module. The error must reach the caller with its
    original text intact, without a location block and without being replaced
    by a failure from the printer or renderer.
    """

    @tvm.transform.module_pass(opt_level=0, name="ThrowUnresolvable")
    class ThrowUnresolvable:
        """Module pass that always fails with a context-free error."""

        def transform_module(self, mod, ctx):
            # A bare error with no VisitErrorContext payload -> nothing to resolve.
            raise tvm.error.InternalError("deliberate failure with no resolvable location")

    bad_mod = _bad_matmul_module()

    with pytest.raises(tvm.error.InternalError) as raised:
        failing_pass = ThrowUnresolvable()
        failing_pass(bad_mod)

    rendered = str(raised.value)
    # The message raised inside the pass survives unchanged.
    assert "deliberate failure with no resolvable location" in rendered
    # No context => no location block appended, but also no crash.
    assert "Location (TVMScript):" not in rendered
if __name__ == "__main__":
    # Allow running this file directly as a script via TVM's test entry point.
    tvm.testing.main()