blob: 01af60724cbb55976838006a8d45d8f772d464d1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
import pytest
from tvm import te
from tvm_ffi.access_path import AccessPath
from tvm.script import tir as T, ir as I
def consistent_equal(x, y, map_free_vars=False):
    """Check structural equality of ``x`` and ``y`` and verify its invariants.

    Equality is evaluated in both directions and cross-checked against the
    structural hashes of both operands.

    Parameters
    ----------
    x, y :
        The IR objects to compare.
    map_free_vars : bool
        Forwarded unchanged to ``tvm.ir.structural_equal`` and
        ``tvm.ir.structural_hash``.

    Returns
    -------
    bool
        Whether ``x`` and ``y`` are structurally equal.

    Raises
    ------
    ValueError
        If ``equal(x, y)`` disagrees with ``equal(y, x)``, or if the equality
        result disagrees with whether the two hashes match.
    """
    forward = tvm.ir.structural_equal(x, y, map_free_vars)
    backward = tvm.ir.structural_equal(y, x, map_free_vars)
    hash_x = tvm.ir.structural_hash(x, map_free_vars)
    hash_y = tvm.ir.structural_hash(y, map_free_vars)

    if forward != backward:
        raise ValueError(f"Non-commutative {x} vs {y}, sequal0={forward}, sequal1={backward}")

    # Equal objects must hash equally. The converse can only break on a real
    # hash collision, which is rare and assumed not to occur for the inputs
    # used in these tests.
    if forward != (hash_x == hash_y):
        raise ValueError(
            f"Inconsistent {x} vs {y}, sequal={forward}, xhash={hash_x}, yhash={hash_y}"
        )
    return forward
def get_sequal_mismatch(x, y, map_free_vars=False):
    """Return the first structural mismatch between ``x`` and ``y``.

    The mismatch is computed in both directions. The reverse comparison must
    report the same pair of paths with lhs and rhs swapped.

    Returns
    -------
    None or pair
        ``None`` when neither direction reports a mismatch; otherwise the
        ``(lhs_path, rhs_path)`` result of comparing ``x`` against ``y``.

    Raises
    ------
    ValueError
        If only one direction reports a mismatch, or the two directions
        report paths that are not mirror images of each other.
    """
    forward = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars)
    backward = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars)

    if forward is None and backward is None:
        return None

    # Exactly one side being None, or non-mirrored paths, means the
    # comparison is not symmetric.
    asymmetric = (
        forward is None
        or backward is None
        or forward[0] != backward[1]
        or forward[1] != backward[0]
    )
    if asymmetric:
        raise ValueError(
            f"Non-commutative {x} vs {y}, mismatch_0={forward}, mismatch_1={backward}"
        )
    return forward
def test_exprs():
    """Structural equality of TIR expressions, with and without free-var mapping."""
    c1 = tvm.tir.const(1, "int32")
    c10 = tvm.tir.const(10, "int32")
    vx = te.var("x")
    vy = te.var("y")
    vz = te.var("z")
    double_x = vx + vx
    double_y = vy + vy

    # Freshly built nodes over the same var are equal.
    assert consistent_equal(double_x * double_x, (vx + vx) * (vx + vx), map_free_vars=False)

    # assert_structural_equal raises ValueError on a mismatch.
    with pytest.raises(ValueError):
        tvm.ir.assert_structural_equal(c1, c10)

    # Distinct free vars only match when free-var mapping is enabled.
    assert not consistent_equal(vx, vy)
    assert consistent_equal(vx, vy, map_free_vars=True)
    # Corner case: once lhs vx maps to rhs vy, lhs vx cannot also map to rhs vx.
    assert not consistent_equal(vx + vx, vy + vx, map_free_vars=True)
    # Corner case: swapped mapping lhs vx <-> rhs vy, lhs vy <-> rhs vx.
    assert consistent_equal(vx + vy, vy + vx, map_free_vars=True)
    # Corner case: rolling remap across three vars.
    assert consistent_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
    assert not consistent_equal(vx + 1, vy + 1, map_free_vars=False)

    # Vars introduced by a Let definition are remapped even without map_free_vars.
    assert consistent_equal(tvm.tir.Let(vx, 1, vx - 1), tvm.tir.Let(vy, 1, vy - 1))
    # A free var (vz) that is the same object on both sides matches itself.
    assert consistent_equal(tvm.tir.Let(vx, 1, vx // vz), tvm.tir.Let(vy, 1, vy // vz))

    assert consistent_equal(double_x * double_x, double_x * double_x)
    assert consistent_equal(double_x * double_x, double_y * double_y, map_free_vars=True)
    assert not consistent_equal(double_x * double_x, double_y * double_y, map_free_vars=False)
def test_prim_func():
    """PrimFunc equality: body operand order matters; JSON copies and tensor attrs compare equal."""
    x = te.var("x")
    y = te.var("y")

    # Counter-example: same params, but the body computes y + x instead of x + y.
    lhs = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x + y))
    rhs = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(y + x))
    assert not consistent_equal(lhs, rhs)

    buf = tvm.tir.decl_buffer((x,), "float32")
    body = tvm.tir.LetStmt(x, 10, tvm.tir.Evaluate(x + 1))
    original = tvm.tir.PrimFunc([x, y, buf], body)
    # A JSON save/load round trip is the simplest way to obtain a deep copy.
    copied = tvm.ir.load_json(tvm.ir.save_json(original))
    tvm.ir.assert_structural_equal(original, copied)

    # Attach two distinct tensors holding identical contents as attributes.
    original = original.with_attr("data", tvm.runtime.tensor([1, 2, 3]))
    copied = copied.with_attr("data", tvm.runtime.tensor([1, 2, 3]))

    # Each function wrapped in its own IRModule still compares equal.
    tvm.ir.assert_structural_equal(
        tvm.IRModule.from_expr(original), tvm.IRModule.from_expr(copied)
    )
def test_prim_func_param_count_mismatch():
    """An extra trailing param on the rhs is reported as a missing item on the lhs."""
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    shorter = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x))
    longer = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x))

    lhs_path, rhs_path = get_sequal_mismatch(shorter, longer)

    assert lhs_path == AccessPath.root().attr("params").array_item_missing(2)
    assert rhs_path == AccessPath.root().attr("params").array_item(2)
def test_prim_func_param_dtype_mismatch():
    """A param that differs only in dtype is reported at that param's ``dtype`` field."""
    x = te.var("x")
    # Both second params share the name hint "y" so dtype is the only field
    # that differs. The expected path then does not depend on whether, or in
    # which order, the name hint participates in the comparison.
    y_int = te.var("y", dtype="int32")
    y_float = te.var("y", dtype="float32")
    func0 = tvm.tir.PrimFunc([x, y_int], tvm.tir.Evaluate(x))
    func1 = tvm.tir.PrimFunc([x, y_float], tvm.tir.Evaluate(x))

    lhs_path, rhs_path = get_sequal_mismatch(func0, func1)

    expected_path = AccessPath.root().attr("params").array_item(1).attr("dtype")
    assert lhs_path == expected_path
    assert rhs_path == expected_path
def test_prim_func_body_mismatch():
    """Corresponding params are matched, so the mismatch lands on the body's second operand."""
    lhs_x, lhs_y = te.var("x"), te.var("y")
    rhs_x, rhs_y = te.var("x"), te.var("y")
    # lhs evaluates x + x while rhs evaluates x + y.
    lhs_func = tvm.tir.PrimFunc([lhs_x, lhs_y], tvm.tir.Evaluate(lhs_x + lhs_x))
    rhs_func = tvm.tir.PrimFunc([rhs_x, rhs_y], tvm.tir.Evaluate(rhs_x + rhs_y))

    lhs_path, rhs_path = get_sequal_mismatch(lhs_func, rhs_func)

    expected = AccessPath.root().attr("body").attr("value").attr("b")
    assert lhs_path == expected
    assert rhs_path == expected
def test_array():
    """Tensors with identical contents are equal; a reshaped copy is not."""
    values = np.arange(10)
    flat_a = tvm.runtime.tensor(values)
    flat_b = tvm.runtime.tensor(values)
    reshaped = tvm.runtime.tensor(values.reshape(2, 5))

    assert consistent_equal(flat_a, flat_b)
    assert not consistent_equal(flat_a, reshaped)
def test_env_func():
    """Two EnvFunc handles looked up by the same registered name compare equal."""

    @tvm.register_global_func("test.sequal.env_func")
    def _add_one(x):
        return x + 1

    first = tvm.ir.EnvFunc.get("test.sequal.env_func")
    second = tvm.ir.EnvFunc.get("test.sequal.env_func")
    assert consistent_equal(second, first)
def test_attrs():
    """Attrs nodes compare by field values; converted maps ignore key insertion order."""
    attrs_a = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
    attrs_b = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
    attrs_other_axis = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
    tvm.ir.assert_structural_equal(attrs_b, attrs_a)
    assert not consistent_equal(attrs_b, attrs_other_axis)

    # Same entries inserted in a different order are equal; a longer list value is not.
    map_a = tvm.runtime.convert({"x": [1, 2, 3], "y": 2})
    map_b = tvm.runtime.convert({"y": 2, "x": [1, 2, 3]})
    map_longer_list = tvm.runtime.convert({"y": 2, "x": [1, 2, 3, 4]})
    assert consistent_equal(map_b, map_a)
    assert not consistent_equal(map_b, map_longer_list)
def test_stmt():
    """Two independently built copies of the same ir_builder statement compare equal."""
    # The builder only needs this buffer and a symbolic extent for the outer loop.
    # (The previous unused placeholders/vars and the reuse of ``n`` for both
    # the buffer size and the loop extent were dead code and have been removed.)
    Ab = tvm.tir.decl_buffer((128,), name="A")
    n = te.var("n")

    def func2():
        # Builds a fresh statement on every call, so the comparison below is
        # between two distinct but structurally identical objects.
        ib = tvm.tir.ir_builder.create()
        A = ib.buffer_ptr(Ab)
        with ib.for_range(0, n, name="i") as i:
            A[i] = A[i] + 1
            with ib.for_range(0, 10, name="j") as j:
                A[j] = A[j] + 2
                A[j] = A[j] + 2
        return ib.get()

    assert consistent_equal(func2(), func2())
def test_buffer_storage_scope():
    """Buffer scope participates in equality; scope "" matches the default scope."""
    handle = te.var("x", dtype="handle")
    shape = (10, 10)

    def make_func(buf):
        # Every function shares the same handle param; only the mapped buffer varies.
        return tvm.tir.PrimFunc([handle], tvm.tir.Evaluate(handle), buffer_map={handle: buf})

    local_a = make_func(tvm.tir.decl_buffer(shape, "float32", scope="local"))
    local_b = make_func(tvm.tir.decl_buffer(shape, "float32", scope="local"))
    default_scope = make_func(tvm.tir.decl_buffer(shape, "float32"))
    empty_scope = make_func(tvm.tir.decl_buffer(shape, "float32", scope=""))

    assert consistent_equal(local_a, local_b)
    assert consistent_equal(default_scope, empty_scope)
    assert not consistent_equal(local_a, default_scope)
def test_buffer_map_mismatch():
    """A differing buffer_map shape is reported at the offending dimension's value."""
    x = te.var("x")

    def make_func(shape):
        return tvm.tir.PrimFunc(
            [x], tvm.tir.Evaluate(x), buffer_map={x: tvm.tir.decl_buffer(shape)}
        )

    func_10x10 = make_func((10, 10))
    func_10x10_clone = make_func((10, 10))
    func_10x20 = make_func((10, 20))

    lhs_path, rhs_path = get_sequal_mismatch(func_10x10, func_10x20)
    expected_path = (
        AccessPath.root().attr("buffer_map").map_item(x).attr("shape").array_item(1).attr("value")
    )
    assert lhs_path == expected_path
    assert rhs_path == expected_path

    # A separately declared buffer of the same shape is not a mismatch.
    assert get_sequal_mismatch(func_10x10, func_10x10_clone) is None
def test_buffer_map_length_mismatch():
    """An extra buffer_map entry on the rhs is reported as a missing map item on the lhs."""
    x = te.var("x")
    extra_key = te.var("x")
    shared_buf = tvm.tir.decl_buffer((10, 10))
    extra_buf = tvm.tir.decl_buffer((10, 20))
    one_entry = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: shared_buf})
    two_entries = tvm.tir.PrimFunc(
        [x], tvm.tir.Evaluate(x), buffer_map={x: shared_buf, extra_key: extra_buf}
    )

    lhs_path, rhs_path = get_sequal_mismatch(one_entry, two_entries)

    assert lhs_path == AccessPath.root().attr("buffer_map").map_item_missing(extra_key)
    assert rhs_path == AccessPath.root().attr("buffer_map").map_item(extra_key)
def test_buffer_load_store():
    """BufferLoad / BufferStore on the same buffer compare by their indices."""
    buf = tvm.tir.decl_buffer((10, 10), "float32")

    load_a = tvm.tir.BufferLoad(buf, [0, 1])
    load_b = tvm.tir.BufferLoad(buf, [0, 1])
    load_other = tvm.tir.BufferLoad(buf, [1, 2])
    assert consistent_equal(load_b, load_a)
    assert not consistent_equal(load_b, load_other)

    idx = tvm.tir.Var("x", "int32")
    store_a = tvm.tir.BufferStore(buf, 0.1, [0, idx])
    store_b = tvm.tir.BufferStore(buf, 0.1, [0, idx])
    store_other = tvm.tir.BufferStore(buf, 0.1, [1, idx])
    assert consistent_equal(store_b, store_a)
    assert not consistent_equal(store_b, store_other)
def test_while():
    """While loops over different free vars are equal only under free-var mapping."""

    def make_loop(var):
        return tvm.tir.While(var > 0, tvm.tir.Evaluate(var))

    loop_x = make_loop(tvm.tir.Var("x", "int32"))
    loop_y = make_loop(tvm.tir.Var("y", "int32"))

    assert not consistent_equal(loop_x, loop_y)
    assert consistent_equal(loop_x, loop_y, map_free_vars=True)
def test_while_condition_mismatch():
    """Loops differing only in their condition are reported at ``condition``."""
    x = tvm.tir.Var("x", "int32")
    while_positive = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
    while_negative = tvm.tir.While(x < 0, tvm.tir.Evaluate(x))

    lhs_path, rhs_path = get_sequal_mismatch(while_positive, while_negative)

    expected = AccessPath.root().attr("condition")
    assert lhs_path == expected
    assert rhs_path == expected
def test_while_body_mismatch():
    """Loops differing only in the evaluated body value are reported at ``body.value``."""
    x = tvm.tir.Var("x", "int32")
    evaluates_x = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
    evaluates_x_plus_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1))

    lhs_path, rhs_path = get_sequal_mismatch(evaluates_x, evaluates_x_plus_1)

    expected = AccessPath.root().attr("body").attr("value")
    assert lhs_path == expected
    assert rhs_path == expected
def test_seq_mismatch():
    """A single differing statement in a SeqStmt is reported at its index."""
    x = tvm.tir.Var("x", "int32")
    # Both sequences start with Evaluate(x), followed by Evaluate(x + k);
    # they differ only in the third statement's constant (2 vs 99).
    lhs_seq = tvm.tir.SeqStmt(
        [tvm.tir.Evaluate(x)] + [tvm.tir.Evaluate(x + k) for k in (1, 2, 3)]
    )
    rhs_seq = tvm.tir.SeqStmt(
        [tvm.tir.Evaluate(x)] + [tvm.tir.Evaluate(x + k) for k in (1, 99, 3)]
    )

    lhs_path, rhs_path = get_sequal_mismatch(lhs_seq, rhs_seq)

    expected = AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value")
    assert lhs_path == expected
    assert rhs_path == expected
def test_seq_mismatch_different_lengths():
    """An element mismatch is reported before the difference in sequence length."""
    x = tvm.tir.Var("x", "int32")
    # Four statements vs. three: the third element already differs (x + 2 vs x + 3),
    # and that difference must be reported rather than the length difference.
    four_stmts = tvm.tir.SeqStmt(
        [tvm.tir.Evaluate(x)] + [tvm.tir.Evaluate(x + k) for k in (1, 2, 3)]
    )
    three_stmts = tvm.tir.SeqStmt(
        [tvm.tir.Evaluate(x)] + [tvm.tir.Evaluate(x + k) for k in (1, 3)]
    )

    lhs_path, rhs_path = get_sequal_mismatch(four_stmts, three_stmts)

    expected = AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value")
    assert lhs_path == expected
    assert rhs_path == expected
def test_seq_length_mismatch():
    """When one SeqStmt is a prefix of the other, the extra element is reported."""
    x = tvm.tir.Var("x", "int32")
    four_stmts = tvm.tir.SeqStmt(
        [tvm.tir.Evaluate(x)] + [tvm.tir.Evaluate(x + k) for k in (1, 2, 3)]
    )
    three_stmts = tvm.tir.SeqStmt(
        [tvm.tir.Evaluate(x)] + [tvm.tir.Evaluate(x + k) for k in (1, 2)]
    )

    lhs_path, rhs_path = get_sequal_mismatch(four_stmts, three_stmts)

    # The lhs holds an element at index 3 that the rhs lacks.
    assert lhs_path == AccessPath.root().attr("seq").array_item(3)
    assert rhs_path == AccessPath.root().attr("seq").array_item_missing(3)
def test_ir_module_equal():
    """IRModules compare structurally, and mismatch messages name the function."""

    def make_module(n: int):
        @I.ir_module
        class module:
            @T.prim_func
            def func(A: T.Buffer(1, "int32")):
                for i in range(n):
                    A[0] = A[0] + 1

        return module

    # Equivalent IRModules compare as equivalent even though each has its own
    # GlobalVar, and GlobalVars usually compare by reference.
    tvm.ir.assert_structural_equal(make_module(16), make_module(16))

    # On a difference, the reported location includes the failing function's name.
    expected_location = '<root>.functions[I.GlobalVar("func")].body.extent.value'
    with pytest.raises(ValueError) as err:
        tvm.ir.assert_structural_equal(make_module(16), make_module(32))
    assert expected_location in err.value.args[0]
def test_nan_values_are_equivalent():
    """Structural equality treats two NaN values as equivalent.

    Under IEEE semantics ``NaN == NaN`` is false, and so is
    ``abs(NaN - NaN) < tolerance``. When comparing IR, however, two nodes
    that both hold NaN represent the same thing and must compare (and hash)
    equal.
    """

    @T.prim_func(private=True)
    def func_1():
        return T.float32("nan")

    @T.prim_func(private=True)
    def func_2():
        return T.float32("nan")

    tvm.ir.assert_structural_equal(func_1, func_2)
    hash_1, hash_2 = (tvm.ir.structural_hash(f) for f in (func_1, func_2))
    assert hash_1 == hash_2
def test_all_nan_values_are_equivalent():
    """Every NaN bit pattern is structurally equivalent to every other.

    IEEE 754 encodes NaN as any value whose exponent bits are all set and
    whose mantissa is non-zero, so many distinct bit patterns are NaN. When
    comparing IR, all of them are treated as the same value.
    """

    def bits_to_float32(bits):
        # Reinterpret the 32-bit integer pattern as a float32 without conversion.
        return np.int32(bits).view("float32")

    # Exponent all ones, only the most significant mantissa bit set.
    top_bit_nan = T.float32(bits_to_float32(0x7FC00000))
    # Exponent all ones, only the least significant mantissa bit set.
    low_bit_nan = T.float32(bits_to_float32(0x7F800001))

    tvm.ir.assert_structural_equal(top_bit_nan, low_bit_nan)
    assert tvm.ir.structural_hash(top_bit_nan) == tvm.ir.structural_hash(low_bit_nan)
if __name__ == "__main__":
    # This file never imports ``tvm.testing`` explicitly. Import it here rather
    # than relying on ``import tvm`` having loaded that submodule as a side effect.
    import tvm.testing

    tvm.testing.main()