# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests analysis functions of struct info"""
import pytest
import tvm
import tvm.testing
from tvm import TVMError
from tvm import relax as rx
from tvm import tir, ir
from tvm.script import relax as R, tir as T
def test_get_static_type_basic():
# object
s0 = rx.ObjectStructInfo()
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType())
# prim
s1 = rx.PrimStructInfo("float32")
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32"))
def test_get_static_type_shape():
# shape
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
s2 = rx.ShapeStructInfo([1, n + 1, m])
s3 = rx.ShapeStructInfo(ndim=2)
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3))
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2))
def test_get_static_type_tensor():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
s4 = rx.TensorStructInfo([1, n + 1, m], "int64")
tvm.ir.assert_structural_equal(
rx.analysis.get_static_type(s4), rx.TensorType(ndim=3, dtype="int64")
)
def test_get_static_type_tuple():
# tuple
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
s0 = rx.ObjectStructInfo()
s2 = rx.ShapeStructInfo([1, n + 1, m])
s4 = rx.TensorStructInfo([1, n + 1, m], "int64")
t0 = rx.TupleStructInfo([s4, s0])
t1 = rx.TupleStructInfo([t0, s2])
tvm.ir.assert_structural_equal(
rx.analysis.get_static_type(t1),
rx.TupleType(
[
rx.TupleType([rx.TensorType(ndim=3, dtype="int64"), rx.ObjectType()]),
rx.ShapeType(ndim=3),
]
),
)
def test_get_static_type_func():
    # func
def fn_info(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.TensorStructInfo([c, n, m], "float32")
y = rx.TensorStructInfo([c, n, 1], "float32")
z = rx.TensorStructInfo([c, n], "float32")
return rx.FuncStructInfo([x, y], z)
def fn_type():
x = rx.TensorType(ndim=3, dtype="float32")
y = rx.TensorType(ndim=3, dtype="float32")
z = rx.TensorType(ndim=2, dtype="float32")
return rx.FuncType([x, y], z)
    f0 = fn_info(1)
    tvm.ir.assert_structural_equal(rx.analysis.get_static_type(f0), fn_type())
def test_erase_to_well_defined_basic():
s0 = rx.ObjectStructInfo()
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0)
# prim
s1 = rx.PrimStructInfo("float32")
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1)
def test_erase_to_well_defined_shape():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
s2 = rx.ShapeStructInfo([1, n + 1, m])
s3 = rx.ShapeStructInfo(ndim=2)
    # has undefined vars
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3)
)
# all defined
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2)
# replacement
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1])
)
    # partially defined
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3)
)
def test_erase_to_well_defined_tensor():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2))
s0 = rx.TensorStructInfo(rshape, dtype="int32")
# undefined
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s0, None, None),
rx.TensorStructInfo(ndim=2, dtype="int32"),
)
# defined
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s0, None, {rshape: rshape}), s0
)
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}),
rx.TensorStructInfo([1, 2], dtype="int32"),
)
s1 = rx.TensorStructInfo([m + 1, n], dtype="float32")
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1)
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}),
rx.TensorStructInfo([4, 2], dtype="float32"),
)
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}),
rx.TensorStructInfo(ndim=2, dtype="float32"),
)
s2 = rx.TensorStructInfo([1, 2], dtype="float32")
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2)
def test_erase_to_well_defined_tuple():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
s0 = rx.ObjectStructInfo()
s2 = rx.ShapeStructInfo([1, m])
s4 = rx.TensorStructInfo([1, n + 1, m], "int64")
t0 = rx.TupleStructInfo([s4, s0])
t1 = rx.TupleStructInfo([t0, s2])
tvm.ir.assert_structural_equal(
rx.analysis.erase_to_well_defined(t1, {m: m + 1}),
rx.TupleStructInfo(
[
rx.TupleStructInfo(
[rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()]
),
rx.ShapeStructInfo([1, m + 1]),
]
),
)
def test_erase_to_well_defined_func():
def fn_info(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.TensorStructInfo([c, n, m], "float32")
y = rx.TensorStructInfo([c, n, 1], "float32")
z = rx.TensorStructInfo([c, n], "float32")
return rx.FuncStructInfo([x, y], z)
f0 = fn_info(1)
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0)
def test_base_check():
BR = rx.analysis.BaseCheckResult
bcheck = rx.analysis.struct_info_base_check
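    # A rough reading of the BaseCheckResult levels, as exercised by the
    # assertions below: PASS means lhs is a base of rhs; FAIL_L0 means the two
    # are statically incompatible (dtype/ndim/shape/tuple-length mismatch);
    # FAIL_L1 means rhs is strictly more general than lhs, so only a runtime
    # check could succeed; FAIL_L2 means the mismatch is only in symbolic
    # shape/value details.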
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
obj0 = rx.ObjectStructInfo()
prim0 = rx.PrimStructInfo("int32")
prim1 = rx.PrimStructInfo("float32")
shape0 = rx.ShapeStructInfo(ndim=-1)
shape1 = rx.ShapeStructInfo(ndim=2)
shape2 = rx.ShapeStructInfo(ndim=3)
shape3 = rx.ShapeStructInfo([1, 2, 3])
shape4 = rx.ShapeStructInfo([1, n, 3])
vdevice0 = ir.VDevice()
vdevice1 = ir.VDevice("llvm")
vdevice2 = ir.VDevice("cuda", 0)
vdevice3 = ir.VDevice("cuda", 2)
vdevice4 = ir.VDevice("cuda", 0, "")
tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32")
tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32")
tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32")
tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32")
tensor4 = rx.TensorStructInfo([n, m], "int32")
tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0)
tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1)
tensor9 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice2)
tensor10 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice3)
tensor11 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice4)
tensor12 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0)
tensor13 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1)
tensor14 = rx.TensorStructInfo([n, m, 2], "int32", vdevice2)
tensor15 = rx.TensorStructInfo([n, m, 2], "int32", vdevice3)
tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4)
# obj
assert bcheck(obj0, prim0) == BR.PASS
assert bcheck(obj0, shape1) == BR.PASS
assert bcheck(obj0, tensor2) == BR.PASS
assert obj0.is_base_of(tensor2)
# prim
assert prim0.is_base_of(prim0)
assert not prim0.is_base_of(prim1)
assert bcheck(prim0, obj0) == BR.FAIL_L1
assert bcheck(prim0, prim0) == BR.PASS
assert bcheck(prim0, prim1) == BR.FAIL_L0
# shape
assert bcheck(shape0, obj0) == BR.FAIL_L1
assert bcheck(shape0, prim0) == BR.FAIL_L0
# unknown dim
assert bcheck(shape0, shape1) == BR.PASS
assert bcheck(shape1, shape0) == BR.FAIL_L1
# ndim mismatch
assert bcheck(shape1, shape2) == BR.FAIL_L0
    # lhs does not have symbolic values but ndim matches
assert bcheck(shape2, shape3) == BR.PASS
    # rhs does not have symbolic values but lhs does
assert bcheck(shape3, shape2) == BR.FAIL_L2
# shape mismatch
assert bcheck(shape3, shape4) == BR.FAIL_L2
assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3]))
# tensor
assert bcheck(tensor0, obj0) == BR.FAIL_L1
assert bcheck(tensor0, prim0) == BR.FAIL_L0
assert bcheck(tensor0, shape0) == BR.FAIL_L0
# dtype mismatch
assert bcheck(tensor0, tensor1) == BR.FAIL_L0
assert bcheck(tensor0, tensor3) == BR.FAIL_L0
assert bcheck(tensor3, tensor4) == BR.FAIL_L0
assert bcheck(tensor1, tensor2) == BR.FAIL_L0
# vdevice mismatch
assert bcheck(tensor8, tensor9) == BR.FAIL_L0
assert bcheck(tensor9, tensor10) == BR.FAIL_L0
assert bcheck(tensor10, tensor11) == BR.FAIL_L0
assert bcheck(tensor13, tensor14) == BR.FAIL_L0
assert bcheck(tensor14, tensor15) == BR.FAIL_L0
assert bcheck(tensor15, tensor16) == BR.FAIL_L0
# ndim mismatch
assert bcheck(tensor2, tensor5) == BR.FAIL_L0
# static shape mismatch
assert bcheck(tensor5, tensor6) == BR.FAIL_L0
# match
assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32"))
assert tensor0.is_base_of(tensor2)
assert tensor0.is_base_of(tensor4)
assert tensor0.is_base_of(tensor5)
assert tensor0.is_base_of(tensor6)
assert tensor2.is_base_of(tensor4)
assert tensor3.is_base_of(tensor7)
assert tensor3.is_base_of(tensor8)
assert tensor6.is_base_of(tensor12)
assert tensor6.is_base_of(tensor13)
assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32"))
# tuple
t0 = rx.TupleStructInfo([obj0, tensor0])
t1 = rx.TupleStructInfo([prim0, tensor4])
t2 = rx.TupleStructInfo([obj0, tensor0, obj0])
t3 = rx.TupleStructInfo([tensor0, obj0])
assert t0.is_base_of(t1)
assert bcheck(t0, t2) == BR.FAIL_L0
assert bcheck(t0, t3) == BR.FAIL_L1
assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1]))
assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1
def fn_info_shape(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.TensorStructInfo([c, n, m], "float32")
y = rx.TensorStructInfo([c, n, 1], "float32")
z = rx.TensorStructInfo([c, n], "float32")
return rx.FuncStructInfo([x, y], z)
def fn_info_erased():
x = rx.TensorStructInfo(ndim=3, dtype="float32")
y = rx.TensorStructInfo(ndim=3, dtype="float32")
z = rx.TensorStructInfo(ndim=2, dtype="float32")
return rx.FuncStructInfo([x, y], z)
assert fn_info_shape(1).is_base_of(fn_info_shape(1))
assert fn_info_erased().is_base_of(fn_info_shape(1))
assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2
fopaque = rx.FuncStructInfo.opaque_func()
assert fopaque.is_base_of(fn_info_shape(1))
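# Helper for the derivation tests below: binds `finfo` onto a fresh GlobalVar,
# builds a call whose arguments carry `args_sinfo`, and checks that
# derive_call_ret_struct_info infers `ret` for that call.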
def _check_derive(ctx, finfo, args_sinfo, ret):
gv = rx.GlobalVar("test")
rx.expr._update_struct_info(gv, finfo)
args = []
for i, sinfo in enumerate(args_sinfo):
arg = rx.Var("arg%i" % i, sinfo)
args.append(arg)
call = rx.Call(gv, args)
derived_ret = rx.analysis.derive_call_ret_struct_info(finfo, call, ctx)
tvm.ir.assert_structural_equal(ret, derived_ret)
def test_derive_call_ret_struct_info():
obj0 = rx.ObjectStructInfo()
prim0 = rx.PrimStructInfo("float32")
n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64")
bb = rx.BlockBuilder()
# derivation cases
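    # def_vars registers n and m as in-scope symbolic variables for the
    # struct info derivations performed below.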
with bb.testing_scope(def_vars=[n, m]):
def func0(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.TensorStructInfo([n, m], "float32")
z = rx.TensorStructInfo([m + c, n], "float32")
return rx.FuncStructInfo([x], z)
# Tensor => Tensor
_check_derive(
bb,
func0(1),
[rx.TensorStructInfo([10, 11], "float32")],
rx.TensorStructInfo([12, 10], "float32"),
)
_check_derive(
bb,
func0(2),
[rx.TensorStructInfo([n, m], "float32")],
rx.TensorStructInfo([m + 2, n], "float32"),
)
        # Passing in information from which n, m cannot be deduced is
        # still OK since the types still match; the result is an erased
        # output.
_check_derive(
bb,
func0(2),
[rx.TensorStructInfo(ndim=2, dtype="float32")],
rx.TensorStructInfo(ndim=2, dtype="float32"),
)
# Error: wrong number of arguments
with pytest.raises(TVMError):
_check_derive(
bb,
func0(2),
[rx.TensorStructInfo(ndim=2, dtype="float32"), obj0],
rx.TensorStructInfo(ndim=2, dtype="float32"),
)
        # Error: type mismatch
with pytest.raises(TVMError):
_check_derive(bb, func0(2), [obj0], obj0)
# Tensor with vdevice
vdev = ir.VDevice("llvm")
def func1(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.TensorStructInfo([n, m], "float32", vdev)
z = rx.TensorStructInfo([m + c, n], "float32", vdev)
return rx.FuncStructInfo([x], z)
_check_derive(
bb,
func1(1),
[rx.TensorStructInfo([10, 11], "float32", vdev)],
rx.TensorStructInfo([12, 10], "float32", vdev),
)
# opaque derivation
fopaque0 = lambda: rx.FuncStructInfo.opaque_func()
fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0)
_check_derive(bb, fopaque0(), [obj0, prim0], obj0)
_check_derive(bb, fopaque1(), [obj0, prim0], prim0)
# recursive tuple derivation
def func_tuple0(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x0 = rx.TensorStructInfo([n, c], "float32")
x1 = rx.TensorStructInfo([n + c, m], "float32")
z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")])
return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z)
_check_derive(
bb,
func_tuple0(2),
[
rx.TupleStructInfo(
[
rx.TensorStructInfo([n, 2], "float32"),
rx.TensorStructInfo([n + 2, 10], "float32"),
]
)
],
rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]),
)
def func_tuple1(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x0 = rx.TensorStructInfo([n, m], "float32")
x1 = rx.TensorStructInfo([n + c, c], "float32")
z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")])
return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z)
        # Still OK to pass an erased tensor as the second tuple element, since n is captured by the other argument.
_check_derive(
bb,
func_tuple1(4),
[
rx.TupleStructInfo(
[
rx.TensorStructInfo([n, 4], "float32"),
rx.TensorStructInfo(ndim=2, dtype="float32"),
]
)
],
rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]),
)
        # tuple length mismatch causes an error
with pytest.raises(TVMError):
_check_derive(
bb,
func_tuple0(4),
[rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])],
rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]),
)
# mixed shape types
def func_shape_mixed(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x0 = rx.ShapeStructInfo([n, m])
f0 = func_tuple0(c)
z = rx.ShapeStructInfo([m + n, c])
return rx.FuncStructInfo([x0, f0], z)
_check_derive(
bb,
func_shape_mixed(3),
[
rx.ShapeStructInfo([10, 20]),
# have to specify purity because an impure function cannot be passed
# where a pure one is expected
rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2), purity=True),
],
rx.ShapeStructInfo([30, 3]),
)
def _check_lca(lhs, rhs, target):
tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target)
tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target)
def test_struct_info_lca():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
obj0 = rx.ObjectStructInfo()
prim0 = rx.PrimStructInfo("int32")
prim1 = rx.PrimStructInfo("float32")
vdevice0 = ir.VDevice("llvm")
vdevice1 = ir.VDevice("cuda", 0)
shape0 = rx.ShapeStructInfo(ndim=-1)
shape1 = rx.ShapeStructInfo(ndim=2)
shape2 = rx.ShapeStructInfo(ndim=3)
shape3 = rx.ShapeStructInfo([1, 2, 3])
shape4 = rx.ShapeStructInfo([1, n, 3])
tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32")
tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32")
tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32")
tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32")
tensor4 = rx.TensorStructInfo([n, m], "int32")
tensor5 = rx.TensorStructInfo([n, m, 1], "int32")
tensor6 = rx.TensorStructInfo([n, m, 2], "int32")
tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0)
tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1)
tensor9 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0)
tensor10 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1)
# obj
_check_lca(obj0, prim0, obj0)
_check_lca(obj0, prim1, obj0)
# shape
_check_lca(shape0, tensor0, obj0)
_check_lca(shape0, shape1, shape0)
_check_lca(shape1, shape2, shape0)
_check_lca(shape1, shape3, shape0)
_check_lca(shape2, shape3, shape2)
_check_lca(shape3, shape4, shape2)
_check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4)
# tensor
_check_lca(tensor0, prim0, obj0)
_check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None))
_check_lca(tensor0, tensor2, tensor0)
_check_lca(tensor0, tensor4, tensor0)
_check_lca(tensor1, tensor3, tensor1)
_check_lca(tensor3, tensor7, tensor3)
_check_lca(tensor3, tensor8, tensor3)
_check_lca(tensor1, tensor8, tensor1)
_check_lca(tensor6, tensor9, tensor6)
_check_lca(tensor6, tensor10, tensor6)
_check_lca(tensor2, tensor4, tensor2)
_check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32"))
_check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32"))
_check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4)
# tuple
t0 = rx.TupleStructInfo([obj0, tensor0])
t1 = rx.TupleStructInfo([prim0, tensor4])
t2 = rx.TupleStructInfo([obj0, tensor0, obj0])
t3 = rx.TupleStructInfo([tensor0, obj0])
_check_lca(t0, t1, t0)
_check_lca(t0, t2, obj0)
_check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0]))
t5 = rx.TupleStructInfo([t0, t1])
t6 = rx.TupleStructInfo([t1, t2])
_check_lca(t5, t6, rx.TupleStructInfo([t0, obj0]))
t7 = rx.TupleStructInfo([])
_check_lca(t7, rx.TupleStructInfo([]), t7)
def fn_info_shape(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.TensorStructInfo([c, n, m], "float32")
y = rx.TensorStructInfo([c, n, 1], "float32")
z = rx.TensorStructInfo([c, n], "float32")
return rx.FuncStructInfo([x, y], z)
def fn_info_erased():
x = rx.TensorStructInfo(ndim=3, dtype="float32")
y = rx.TensorStructInfo(ndim=3, dtype="float32")
z = rx.TensorStructInfo(ndim=2, dtype="float32")
return rx.FuncStructInfo([x, y], z)
fopaque0 = lambda: rx.FuncStructInfo.opaque_func()
fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0)
fopaque2 = lambda: rx.FuncStructInfo.opaque_func(
ret=rx.TensorStructInfo(ndim=2, dtype="float32")
)
_check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased())
_check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2))
_check_lca(fopaque0(), fopaque1(), fopaque0())
_check_lca(fopaque0(), fn_info_shape(1), fopaque0())
_check_lca(fopaque2(), fn_info_shape(1), fopaque2())
def _generate_prim_test_cases():
dtypes = [
"bool",
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"int64",
"uint64",
"float16",
"float32",
"float64",
]
for dtype in dtypes:
# LCA of a PrimStructInfo with itself yields itself
yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype))
# The LCA of two values, each statically known to be the same
# value, is known to have that value.
yield (
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(0, dtype)),
)
# The LCA of two values, each of which is statically known to
# have a different value, no longer knows the contained value.
yield (
R.Prim(value=tir.const(0, dtype)),
R.Prim(value=tir.const(1, dtype)),
R.Prim(dtype=dtype),
)
# LCA of a known variable with itself yields itself
var_N = tir.Var("N", dtype)
yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N))
# LCA of a known variable with a known static value is no
# longer known to have a specific value.
yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype))
yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype))
var_M = tir.Var("M", dtype)
yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype))
for dtype_a in dtypes:
for dtype_b in dtypes:
if dtype_a != dtype_b:
# Unlike R.Tensor, R.Prim does not currently support a
# value with an unknown datatype. If the dtype
# differs between the two annotations, the next wider
# category is R.Object.
yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object)
# Because the dtypes are different, even `R.Prim` containing
# the same value in different representations (e.g.
# `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`.
yield (
R.Prim(value=tir.const(0, dtype_a)),
R.Prim(value=tir.const(0, dtype_b)),
R.Object,
)
# And the same is true for known variable values
var_N = tir.Var("N", dtype_a)
var_M = tir.Var("M", dtype_b)
yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object)
@pytest.mark.parametrize("test_case", list(_generate_prim_test_cases()))
def test_prim_struct_info_lca(test_case):
def _normalize_sinfo(sinfo):
if isinstance(sinfo, tvm.relax.StructInfo):
return sinfo
elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy):
return sinfo.as_struct_info()
elif callable(sinfo):
return sinfo()
else:
raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo")
lhs, rhs, expected = map(_normalize_sinfo, test_case)
lca = rx.analysis.struct_info_lca(lhs, rhs)
assert tvm.ir.structural_equal(
lca, expected
), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}"
def _generate_tir_var_test_cases():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
shape0 = rx.ShapeStructInfo([1, n, 3])
shape1 = rx.ShapeStructInfo([1, 2 * n, n, m])
shape2 = rx.ShapeStructInfo([1, 2 * n, m])
tensor0 = rx.TensorStructInfo([1, n, 3], "int32")
tensor1 = rx.TensorStructInfo([1, 2 * n, n, m], "int32")
tensor2 = rx.TensorStructInfo([1, 2 * n, m], "int32")
func = rx.FuncStructInfo(
[rx.TensorStructInfo([1, 2 * n, n, m], "int32")], rx.TensorStructInfo([1, n, 3], "int32")
)
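    # Each case is a tuple of (struct_info, TIR vars definable from it, TIR vars it uses).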
yield shape0, [n], [n]
yield shape1, [n, m], [n, m]
yield shape2, [m], [n, m]
yield tensor0, [n], [n]
yield tensor1, [n, m], [n, m]
yield tensor2, [m], [n, m]
yield func, [n, m], [n, m]
tir_var_test_case = tvm.testing.parameter(*_generate_tir_var_test_cases())
def test_tir_vars_in_struct_info(tir_var_test_case):
sinfo, _vars_definable, vars_used = tir_var_test_case
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(sinfo), vars_used)
def test_definable_tir_vars_in_struct_info(tir_var_test_case):
sinfo, vars_definable, _vars_used = tir_var_test_case
tvm.ir.assert_structural_equal(
rx.analysis.definable_tir_vars_in_struct_info(sinfo), vars_definable
)
def test_collect_symbolic_var_from_tensor_shape():
n, m, k, q, p = (
tir.Var("n", "int64"),
tir.Var("m", "int64"),
tir.Var("k", "int64"),
tir.Var("q", "int64"),
tir.Var("p", "int64"),
)
bb = rx.BlockBuilder()
x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32"))
with bb.function("main", [x]):
v0 = bb.match_cast(x, rx.TensorStructInfo([m, k], "float32"))
v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorStructInfo([p, q], "float32")))
bb.emit_func_output(rx.const(1))
func = bb.get()["main"]
defined_vars = set(rx.analysis.defined_symbolic_vars(func))
free_vars = set(rx.analysis.free_symbolic_vars(func))
assert defined_vars == {m, k}
assert free_vars == {n, p, q}
param_type = tvm.testing.parameter("shape_expr", "prim_value")
param_order = tvm.testing.parameter("definition_first", "usage_first")
def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
tir_n = tir.Var("n", "int64")
tir_m = tir.Var("m", "int64")
bb = rx.BlockBuilder()
arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m]))
if param_type == "shape_expr":
extra_params = [
rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])),
]
elif param_type == "prim_value":
extra_params = [
rx.Var("n", rx.PrimStructInfo(value=tir_n)),
rx.Var("m", rx.PrimStructInfo(value=tir_m)),
]
else:
raise ValueError(f"Unknown param_type: {param_type}")
if param_order == "definition_first":
params = [*extra_params, arg]
elif param_order == "usage_first":
params = [arg, *extra_params]
else:
raise ValueError(f"Unknown param_order: {param_order}")
with bb.function("main", params=params):
out = rx.op.reshape(arg, [tir_n, tir_m])
bb.emit_func_output(out)
func = bb.get()["main"]
defined_vars = set(rx.analysis.defined_symbolic_vars(func))
free_vars = set(rx.analysis.free_symbolic_vars(func))
assert defined_vars == {tir_n, tir_m}
assert free_vars == set()
def test_collect_nonnegative_expressions():
@R.function
def func(
A: R.Tensor([1024, "M", "N-2"]),
B: R.Tensor([128, "N", "M+2"]),
C: R.Shape(["M", "N"]),
D: R.Prim(value="N"),
):
return R.tuple()
M, N = list(func.params[2].struct_info.values)
# Expressions are de-duplicated, in order of their first appearance
tvm.ir.assert_structural_equal(
rx.analysis.collect_non_negative_expressions(func.struct_info),
[M, N - 2, N, M + 2],
)
    # Tensor shapes can imply that their dimensions are non-negative
tvm.ir.assert_structural_equal(
rx.analysis.collect_non_negative_expressions(func.params[0].struct_info),
[M, N - 2],
)
tvm.ir.assert_structural_equal(
rx.analysis.collect_non_negative_expressions(func.params[1].struct_info),
[N, M + 2],
)
# ShapeExpr values can imply that their contents are non-negative
tvm.ir.assert_structural_equal(
rx.analysis.collect_non_negative_expressions(func.params[2].struct_info),
[M, N],
)
# PrimValue instances may contain negative values, and do not
# imply that their contents are non-negative.
tvm.ir.assert_structural_equal(
rx.analysis.collect_non_negative_expressions(func.params[3].struct_info),
[],
)
if __name__ == "__main__":
tvm.testing.main()