blob: 41b0e714d1d02cb3271ae0330186a4fa7f85d1c2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import pytest
import tvm
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import relax as R, tir as T
def test_copy_with_new_vars():
    """Copying a function yields an equal function whose parameters differ.

    The copy must be structurally equal to the original, while each of its
    parameters must compare unequal to the corresponding original parameter.
    """

    @R.function
    def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
        gv = R.add(x, y)
        return gv

    copied = relax.utils.copy_with_new_vars(before)

    # The copy is structurally identical to the original ...
    assert_structural_equal(copied, before)

    # ... but none of its parameters is shared with the original.
    original_params, copied_params = before.params, copied.params
    assert len(copied_params) == len(original_params)
    for original_param, copied_param in zip(original_params, copied_params):
        assert original_param != copied_param
def test_copy_with_new_vars_copied_symbolic_vars():
    """Symbolic shape variables are replaced along with the parameters.

    Both parameters share the symbolic dimension ``m``. In the copy, the
    leading dimension of each parameter must differ from the original's.
    """

    @R.function
    def before(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
        gv = R.add(x, y)
        return gv

    copied = relax.utils.copy_with_new_vars(before)
    assert_structural_equal(copied, before)
    assert len(copied.params) == len(before.params)

    for original_param, copied_param in zip(before.params, copied.params):
        assert original_param != copied_param
        # The leading dimension of the copied parameter must not be the
        # original symbolic var.
        original_dim = original_param.struct_info.shape[0]
        copied_dim = copied_param.struct_info.shape[0]
        assert original_dim != copied_dim
def test_copy_with_new_vars_on_ir_module():
    """A copied function can be inserted into the module it was copied from."""

    @tvm.script.ir_module
    class Actual:
        @R.function
        def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
            gv = R.add(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
            gv = R.add(x, y)
            return gv

        @R.function
        def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
            gv = R.add(x, y)
            return gv

    copied = relax.utils.copy_with_new_vars(Actual["func"])
    Actual["func_copied"] = copied.with_attr("global_symbol", "func_copied")

    # Structural equality maps the vars of one module onto those of the
    # other. The check therefore fails if "func_copied" reuses any VarNode
    # that also appears in "func".
    assert_structural_equal(Actual, Expected)
def test_copy_with_new_vars_on_ir_module_nested_function():
    """Like the module test above, but the copied function has a nested function."""

    @tvm.script.ir_module
    class Actual:
        @R.function
        def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
            @R.function
            def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
                gv = R.add(x, x)
                return gv

            gv = R.add(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
            @R.function
            def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
                gv = R.add(x, x)
                return gv

            gv = R.add(x, y)
            return gv

        @R.function
        def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
            @R.function
            def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
                gv = R.add(x, x)
                return gv

            gv = R.add(x, y)
            return gv

    copied = relax.utils.copy_with_new_vars(Actual["func"])
    Actual["func_copied"] = copied.with_attr("global_symbol", "func_copied")

    # Equality would fail if "func_copied" shared VarNodes with "func",
    # because structural equality maps each var of one module to a var
    # of the other.
    assert_structural_equal(Actual, Expected)
def test_assert_structural_equal_in_seqexpr():
    """The first mismatch is correctly identified.

    The two functions already differ in the operator of their first
    binding, so that location must be the one reported in the error.
    """

    @R.function(private=True)
    def func_1(A: R.Tensor([16, 16], "float32")):
        B = R.concat([A, A])
        return B

    @R.function(private=True)
    def func_2(A: R.Tensor([16, 16], "float32")):
        B = R.add(A, A)
        C = R.add(B, B)
        return B

    # concat vs. add in bindings[0] is reported, even though the number of
    # bindings (one vs. two) also differs between the functions.
    first_mismatch = "<root>.body.blocks[0].bindings[0].value.op"
    with pytest.raises(ValueError, match=re.escape(first_mismatch)):
        assert_structural_equal(func_1, func_2)
def test_structural_equal_of_call_nodes():
    """relax.Call must be compared by structural equality, not reference.

    One function binds the same ``relax.Call`` object twice. The other binds
    two separately constructed, but identical, calls. The two functions must
    compare equal.
    """
    # Three identical, independently constructed calls to relax.op.zeros.
    calls_to_op_zero = [relax.op.zeros([16], "int32") for _ in range(3)]

    @R.function(private=True)
    def uses_same_object_twice():
        A = calls_to_op_zero[0]
        B = calls_to_op_zero[0]
        C = R.add(A, B)
        return C

    @R.function(private=True)
    def uses_two_different_objects():
        A = calls_to_op_zero[1]
        B = calls_to_op_zero[2]
        C = R.add(A, B)
        return C

    assert_structural_equal(uses_same_object_twice, uses_two_different_objects)
def test_structural_equal_with_recursive_lambda_function():
    """A recursive lambda function may be checked for structural equality

    Recursive function definitions may reference the bound variable
    within the value being bound. In these cases, the `DefEqual(var,
    other->var)` must occur first, so that the variable is defined at
    its point of use.

    In all other cases, checking the bound value for structural equality
    before the variable gives a better error message.
    """

    def build_recursive_func():
        @R.function
        def func(n: R.Prim("int64")):
            @R.function
            def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
                i = T.int64()
                if R.prim_value(i == 0):
                    output = R.prim_value(T.int64(0))
                else:
                    remainder_relax = recursive_lambda(R.prim_value(i - 1))
                    remainder_tir = T.int64()
                    _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
                    output = R.prim_value(i + remainder_tir)
                return output

            return recursive_lambda(n)

        return func

    # Two independently parsed copies of the same recursive definition.
    func_1, func_2 = build_recursive_func(), build_recursive_func()

    assert_structural_equal(func_1, func_2)
def test_structural_equal_with_distinct_recursive_lambda_function():
    """A recursive lambda function may be checked for structural equality

    Like `test_structural_equal_with_recursive_lambda_function`, but
    comparing two distinct functions. The error must name the location
    of the first difference.
    """

    @R.function(private=True)
    def func_a(n: R.Prim("int64")):
        @R.function
        def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
            i = T.int64()
            if R.prim_value(i == 0):
                output = R.prim_value(T.int64(0))
                #                             ^
                # The first mismatch is here  ^
            else:
                remainder_relax = recursive_lambda(R.prim_value(i - 1))
                remainder_tir = T.int64()
                _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
                output = R.prim_value(i + remainder_tir)
            return output

        return recursive_lambda(n)

    @R.function(private=True)
    def func_b(n: R.Prim("int64")):
        @R.function
        def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
            i = T.int64()
            if R.prim_value(i == 0):
                output = R.prim_value(T.int64(1))
                #                             ^
                # The first mismatch is here  ^
            else:
                remainder_relax = recursive_lambda(R.prim_value(i - 1))
                remainder_tir = T.int64()
                _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
                output = R.prim_value(i * remainder_tir)
            return output

        return recursive_lambda(n)

    # Path components leading to the first mismatch; the dotted path must
    # appear in the error message.
    mismatch_path = (
        "<root>",
        "body",
        "blocks[0]",
        "bindings[0]",
        "value",
        "body",
        "blocks[0]",
        "bindings[0]",
        "value",
        "true_branch",
        "body",
        "value",
        "value",
    )
    expected_location = ".".join(mismatch_path)

    with pytest.raises(ValueError, match=re.escape(expected_location)):
        assert_structural_equal(func_a, func_b)
if __name__ == "__main__":
    # pytest.main returns the session's exit code. Raising SystemExit with it
    # lets a direct run of this file report failures to the shell or CI.
    # Otherwise the script would always exit with status 0.
    raise SystemExit(pytest.main([__file__]))