blob: 5c587354cc3f8d93f1488907c6de9f597203ee99 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
from tvm.ir import assert_structural_equal
from tvm.relay.op.transform import split
from tvm.runtime import ObjectPath
from tvm.script import ir as I, tir as T
def _error_message(exception):
splitter = "ValueError: StructuralEqual"
return splitter + str(exception).split(splitter)[1]
def _expected_result(func1, func2, objpath1, objpath2):
return f"""ValueError: StructuralEqual check failed, caused by lhs at {objpath1}:
{func1.script(path_to_underline=[objpath1], syntax_sugar=False)}
and rhs at {objpath2}:
{func2.script(path_to_underline=[objpath2], syntax_sugar=False)}"""
def test_prim_func_buffer_map():
    """A buffer shape mismatch is reported through ``buffer_map``.

    ``func1`` and ``func2`` differ only in the second dimension of buffer
    ``B`` (128 vs. 256). The error message is expected to underline
    ``buffer_map[params[1]].shape[1].value`` on both sides.
    """

    @T.prim_func
    def func1(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128, 128))
        B = T.match_buffer(b, (128, 128))

    @T.prim_func
    def func2(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128, 128))
        B = T.match_buffer(b, (128, 256))

    def mismatch_path(func):
        # Path to the value of the second shape entry of the buffer mapped
        # from the function's second parameter.
        return (
            ObjectPath.root()
            .attr("buffer_map")
            .map_value(func.params[1])
            .attr("shape")
            .array_index(1)
            .attr("value")
        )

    with pytest.raises(ValueError) as exc_info:
        assert_structural_equal(func1, func2)

    actual = _error_message(exc_info.value)
    expected = _expected_result(func1, func2, mismatch_path(func1), mismatch_path(func2))
    assert actual == expected
def test_evaluate():
    """A differing ``T.evaluate`` constant inside an IRModule is located.

    The two modules each hold one function ``func`` and differ only in the
    evaluated constant (0 vs. 1). The error message is expected to underline
    ``functions[func].body.value.value`` on both sides.
    """

    @I.ir_module
    class module1:
        @T.prim_func
        def func():
            T.evaluate(0)

    @I.ir_module
    class module2:
        @T.prim_func
        def func():
            T.evaluate(1)

    def mismatch_path(module):
        # Path through the module's function map down to the evaluated value.
        return (
            ObjectPath.root()
            .attr("functions")
            .map_value(module.get_global_var("func"))
            .attr("body")
            .attr("value")
            .attr("value")
        )

    with pytest.raises(ValueError) as exc_info:
        assert_structural_equal(module1, module2)

    actual = _error_message(exc_info.value)
    expected = _expected_result(module1, module2, mismatch_path(module1), mismatch_path(module2))
    assert actual == expected
def test_allocate():
    """A differing allocation extent is located.

    ``func1`` allocates ``(128, 128)`` while ``func2`` allocates
    ``(256, 128)``. The error message is expected to underline
    ``body.extents[0].value`` on both sides.
    """

    @T.prim_func
    def func1():
        a_data = T.allocate((128, 128), dtype="float32")
        a = T.decl_buffer((128, 128), dtype="float32", data=a_data)

    @T.prim_func
    def func2():
        a_data = T.allocate((256, 128), dtype="float32")
        a = T.decl_buffer((256, 128), dtype="float32", data=a_data)

    with pytest.raises(ValueError) as exc_info:
        assert_structural_equal(func1, func2)

    # The same path is expected on the lhs and the rhs.
    extent_path = ObjectPath.root().attr("body").attr("extents").array_index(0).attr("value")
    actual = _error_message(exc_info.value)
    expected = _expected_result(func1, func2, extent_path, extent_path)
    assert actual == expected
def test_for():
    """A differing loop-nest depth is located.

    ``func1`` iterates a two-dimensional grid and ``func2`` a
    three-dimensional one. The error message is expected to underline
    ``body.block.body.body.body`` on both sides (presumably the body of the
    second loop, where ``func1`` has the block and ``func2`` a third loop).
    """

    @T.prim_func
    def func1():
        for i, j in T.grid(128, 128):
            with T.block():
                pass

    @T.prim_func
    def func2():
        for i, j, k in T.grid(128, 128, 128):
            with T.block():
                pass

    with pytest.raises(ValueError) as exc_info:
        assert_structural_equal(func1, func2)

    # The same path is expected on the lhs and the rhs; build it step by step.
    mismatch_path = ObjectPath.root()
    for field in ("body", "block", "body", "body", "body"):
        mismatch_path = mismatch_path.attr(field)

    actual = _error_message(exc_info.value)
    expected = _expected_result(func1, func2, mismatch_path, mismatch_path)
    assert actual == expected
if __name__ == "__main__":
    # Import the submodule explicitly instead of relying on another import
    # having loaded `tvm.testing` as a side effect of `import tvm`.
    import tvm.testing

    tvm.testing.main()