| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Tests analysis functions of struct info""" |
| |
| import pytest |
| |
| import tvm |
| import tvm.testing |
| from tvm import TVMError |
| from tvm import relax as rx |
| from tvm import tir, ir |
| from tvm.script import relax as R, tir as T |
| |
| |
| def test_get_static_type_basic(): |
| # object |
| s0 = rx.ObjectStructInfo() |
| tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType()) |
| |
| # prim |
| s1 = rx.PrimStructInfo("float32") |
| tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32")) |
| |
| |
| def test_get_static_type_shape(): |
| # shape |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| |
| s2 = rx.ShapeStructInfo([1, n + 1, m]) |
| s3 = rx.ShapeStructInfo(ndim=2) |
| |
| tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3)) |
| |
| tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2)) |
| |
| |
| def test_get_static_type_tensor(): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| s4 = rx.TensorStructInfo([1, n + 1, m], "int64") |
| |
| tvm.ir.assert_structural_equal( |
| rx.analysis.get_static_type(s4), rx.TensorType(ndim=3, dtype="int64") |
| ) |
| |
| |
| def test_get_static_type_tuple(): |
| # tuple |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| s0 = rx.ObjectStructInfo() |
| s2 = rx.ShapeStructInfo([1, n + 1, m]) |
| s4 = rx.TensorStructInfo([1, n + 1, m], "int64") |
| t0 = rx.TupleStructInfo([s4, s0]) |
| t1 = rx.TupleStructInfo([t0, s2]) |
| |
| tvm.ir.assert_structural_equal( |
| rx.analysis.get_static_type(t1), |
| rx.TupleType( |
| [ |
| rx.TupleType([rx.TensorType(ndim=3, dtype="int64"), rx.ObjectType()]), |
| rx.ShapeType(ndim=3), |
| ] |
| ), |
| ) |
| |
| |
| def test_get_static_type_func(): |
| # tuple |
| def fn_info(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = rx.TensorStructInfo([c, n, m], "float32") |
| y = rx.TensorStructInfo([c, n, 1], "float32") |
| z = rx.TensorStructInfo([c, n], "float32") |
| return rx.FuncStructInfo([x, y], z) |
| |
| def fn_type(): |
| x = rx.TensorType(ndim=3, dtype="float32") |
| y = rx.TensorType(ndim=3, dtype="float32") |
| z = rx.TensorType(ndim=2, dtype="float32") |
| return rx.FuncType([x, y], z) |
| |
| f0 = fn_info(1) |
| |
| tvm.ir.assert_structural_equal(rx.analysis.get_static_type(fn_info(1)), fn_type()) |
| |
| |
| def test_erase_to_well_defined_basic(): |
| s0 = rx.ObjectStructInfo() |
| tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0) |
| |
| # prim |
| s1 = rx.PrimStructInfo("float32") |
| tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1) |
| |
| |
| def test_erase_to_well_defined_shape(): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| |
| s2 = rx.ShapeStructInfo([1, n + 1, m]) |
| s3 = rx.ShapeStructInfo(ndim=2) |
| # have undefined |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3) |
| ) |
| # all defined |
| tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2) |
| |
| # replacement |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1]) |
| ) |
| |
| # partial defined |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3) |
| ) |
| |
| |
| def test_erase_to_well_defined_tensor(): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) |
| s0 = rx.TensorStructInfo(rshape, dtype="int32") |
| |
| # undefined |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s0, None, None), |
| rx.TensorStructInfo(ndim=2, dtype="int32"), |
| ) |
| |
| # defined |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s0, None, {rshape: rshape}), s0 |
| ) |
| |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}), |
| rx.TensorStructInfo([1, 2], dtype="int32"), |
| ) |
| |
| s1 = rx.TensorStructInfo([m + 1, n], dtype="float32") |
| |
| tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1) |
| |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}), |
| rx.TensorStructInfo([4, 2], dtype="float32"), |
| ) |
| |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}), |
| rx.TensorStructInfo(ndim=2, dtype="float32"), |
| ) |
| |
| s2 = rx.TensorStructInfo([1, 2], dtype="float32") |
| |
| tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2) |
| |
| |
| def test_erase_to_well_defined_tuple(): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| s0 = rx.ObjectStructInfo() |
| s2 = rx.ShapeStructInfo([1, m]) |
| s4 = rx.TensorStructInfo([1, n + 1, m], "int64") |
| t0 = rx.TupleStructInfo([s4, s0]) |
| t1 = rx.TupleStructInfo([t0, s2]) |
| |
| tvm.ir.assert_structural_equal( |
| rx.analysis.erase_to_well_defined(t1, {m: m + 1}), |
| rx.TupleStructInfo( |
| [ |
| rx.TupleStructInfo( |
| [rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()] |
| ), |
| rx.ShapeStructInfo([1, m + 1]), |
| ] |
| ), |
| ) |
| |
| |
| def test_erase_to_well_defined_func(): |
| def fn_info(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = rx.TensorStructInfo([c, n, m], "float32") |
| y = rx.TensorStructInfo([c, n, 1], "float32") |
| z = rx.TensorStructInfo([c, n], "float32") |
| return rx.FuncStructInfo([x, y], z) |
| |
| f0 = fn_info(1) |
| |
| tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0) |
| |
| |
| def test_base_check(): |
| BR = rx.analysis.BaseCheckResult |
| bcheck = rx.analysis.struct_info_base_check |
| |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| obj0 = rx.ObjectStructInfo() |
| prim0 = rx.PrimStructInfo("int32") |
| prim1 = rx.PrimStructInfo("float32") |
| |
| shape0 = rx.ShapeStructInfo(ndim=-1) |
| shape1 = rx.ShapeStructInfo(ndim=2) |
| shape2 = rx.ShapeStructInfo(ndim=3) |
| shape3 = rx.ShapeStructInfo([1, 2, 3]) |
| shape4 = rx.ShapeStructInfo([1, n, 3]) |
| |
| vdevice0 = ir.VDevice() |
| vdevice1 = ir.VDevice("llvm") |
| vdevice2 = ir.VDevice("cuda", 0) |
| vdevice3 = ir.VDevice("cuda", 2) |
| vdevice4 = ir.VDevice("cuda", 0, "") |
| |
| tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") |
| tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") |
| tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") |
| tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") |
| tensor4 = rx.TensorStructInfo([n, m], "int32") |
| tensor5 = rx.TensorStructInfo([n, m, 1], "int32") |
| tensor6 = rx.TensorStructInfo([n, m, 2], "int32") |
| tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0) |
| tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1) |
| tensor9 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice2) |
| tensor10 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice3) |
| tensor11 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice4) |
| tensor12 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0) |
| tensor13 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1) |
| tensor14 = rx.TensorStructInfo([n, m, 2], "int32", vdevice2) |
| tensor15 = rx.TensorStructInfo([n, m, 2], "int32", vdevice3) |
| tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4) |
| |
| # obj |
| assert bcheck(obj0, prim0) == BR.PASS |
| assert bcheck(obj0, shape1) == BR.PASS |
| assert bcheck(obj0, tensor2) == BR.PASS |
| assert obj0.is_base_of(tensor2) |
| |
| # prim |
| assert prim0.is_base_of(prim0) |
| assert not prim0.is_base_of(prim1) |
| assert bcheck(prim0, obj0) == BR.FAIL_L1 |
| assert bcheck(prim0, prim0) == BR.PASS |
| assert bcheck(prim0, prim1) == BR.FAIL_L0 |
| |
| # shape |
| assert bcheck(shape0, obj0) == BR.FAIL_L1 |
| assert bcheck(shape0, prim0) == BR.FAIL_L0 |
| |
| # unknown dim |
| assert bcheck(shape0, shape1) == BR.PASS |
| assert bcheck(shape1, shape0) == BR.FAIL_L1 |
| |
| # ndim mismatch |
| assert bcheck(shape1, shape2) == BR.FAIL_L0 |
| |
| # lhs do not have symbolic value but ndim match |
| assert bcheck(shape2, shape3) == BR.PASS |
| |
| # rhs do not symbolic but lhs do |
| assert bcheck(shape3, shape2) == BR.FAIL_L2 |
| |
| # shape mismatch |
| assert bcheck(shape3, shape4) == BR.FAIL_L2 |
| assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3])) |
| |
| # tensor |
| assert bcheck(tensor0, obj0) == BR.FAIL_L1 |
| assert bcheck(tensor0, prim0) == BR.FAIL_L0 |
| assert bcheck(tensor0, shape0) == BR.FAIL_L0 |
| |
| # dtype mismatch |
| assert bcheck(tensor0, tensor1) == BR.FAIL_L0 |
| assert bcheck(tensor0, tensor3) == BR.FAIL_L0 |
| assert bcheck(tensor3, tensor4) == BR.FAIL_L0 |
| assert bcheck(tensor1, tensor2) == BR.FAIL_L0 |
| |
| # vdevice mismatch |
| assert bcheck(tensor8, tensor9) == BR.FAIL_L0 |
| assert bcheck(tensor9, tensor10) == BR.FAIL_L0 |
| assert bcheck(tensor10, tensor11) == BR.FAIL_L0 |
| assert bcheck(tensor13, tensor14) == BR.FAIL_L0 |
| assert bcheck(tensor14, tensor15) == BR.FAIL_L0 |
| assert bcheck(tensor15, tensor16) == BR.FAIL_L0 |
| |
| # ndim mismatch |
| assert bcheck(tensor2, tensor5) == BR.FAIL_L0 |
| |
| # static shape mismatch |
| assert bcheck(tensor5, tensor6) == BR.FAIL_L0 |
| |
| # match |
| assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) |
| assert tensor0.is_base_of(tensor2) |
| assert tensor0.is_base_of(tensor4) |
| assert tensor0.is_base_of(tensor5) |
| assert tensor0.is_base_of(tensor6) |
| assert tensor2.is_base_of(tensor4) |
| assert tensor3.is_base_of(tensor7) |
| assert tensor3.is_base_of(tensor8) |
| assert tensor6.is_base_of(tensor12) |
| assert tensor6.is_base_of(tensor13) |
| assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) |
| |
| # tuple |
| t0 = rx.TupleStructInfo([obj0, tensor0]) |
| t1 = rx.TupleStructInfo([prim0, tensor4]) |
| t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) |
| t3 = rx.TupleStructInfo([tensor0, obj0]) |
| |
| assert t0.is_base_of(t1) |
| |
| assert bcheck(t0, t2) == BR.FAIL_L0 |
| assert bcheck(t0, t3) == BR.FAIL_L1 |
| |
| assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) |
| assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 |
| |
| def fn_info_shape(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = rx.TensorStructInfo([c, n, m], "float32") |
| y = rx.TensorStructInfo([c, n, 1], "float32") |
| z = rx.TensorStructInfo([c, n], "float32") |
| return rx.FuncStructInfo([x, y], z) |
| |
| def fn_info_erased(): |
| x = rx.TensorStructInfo(ndim=3, dtype="float32") |
| y = rx.TensorStructInfo(ndim=3, dtype="float32") |
| z = rx.TensorStructInfo(ndim=2, dtype="float32") |
| return rx.FuncStructInfo([x, y], z) |
| |
| assert fn_info_shape(1).is_base_of(fn_info_shape(1)) |
| assert fn_info_erased().is_base_of(fn_info_shape(1)) |
| assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 |
| |
| fopaque = rx.FuncStructInfo.opaque_func() |
| assert fopaque.is_base_of(fn_info_shape(1)) |
| |
| |
| def _check_derive(ctx, finfo, args_sinfo, ret): |
| gv = rx.GlobalVar("test") |
| rx.expr._update_struct_info(gv, finfo) |
| args = [] |
| for i, sinfo in enumerate(args_sinfo): |
| arg = rx.Var("arg%i" % i, sinfo) |
| args.append(arg) |
| call = rx.Call(gv, args) |
| derived_ret = rx.analysis.derive_call_ret_struct_info(finfo, call, ctx) |
| tvm.ir.assert_structural_equal(ret, derived_ret) |
| |
| |
| def test_derive_call_ret_struct_info(): |
| obj0 = rx.ObjectStructInfo() |
| prim0 = rx.PrimStructInfo("float32") |
| |
| n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64") |
| bb = rx.BlockBuilder() |
| # derivation cases |
| with bb.testing_scope(def_vars=[n, m]): |
| |
| def func0(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = rx.TensorStructInfo([n, m], "float32") |
| z = rx.TensorStructInfo([m + c, n], "float32") |
| return rx.FuncStructInfo([x], z) |
| |
| # Tensor => Tensor |
| _check_derive( |
| bb, |
| func0(1), |
| [rx.TensorStructInfo([10, 11], "float32")], |
| rx.TensorStructInfo([12, 10], "float32"), |
| ) |
| |
| _check_derive( |
| bb, |
| func0(2), |
| [rx.TensorStructInfo([n, m], "float32")], |
| rx.TensorStructInfo([m + 2, n], "float32"), |
| ) |
| |
| # passing in information that cannot deduce n, m |
| # it is still OK as type still matches, return an |
| # eriased output |
| _check_derive( |
| bb, |
| func0(2), |
| [rx.TensorStructInfo(ndim=2, dtype="float32")], |
| rx.TensorStructInfo(ndim=2, dtype="float32"), |
| ) |
| |
| # Error: wrong number of arguments |
| with pytest.raises(TVMError): |
| _check_derive( |
| bb, |
| func0(2), |
| [rx.TensorStructInfo(ndim=2, dtype="float32"), obj0], |
| rx.TensorStructInfo(ndim=2, dtype="float32"), |
| ) |
| |
| # Error:type mismatch |
| with pytest.raises(TVMError): |
| _check_derive(bb, func0(2), [obj0], obj0) |
| |
| # Tensor with vdevice |
| vdev = ir.VDevice("llvm") |
| |
| def func1(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = rx.TensorStructInfo([n, m], "float32", vdev) |
| z = rx.TensorStructInfo([m + c, n], "float32", vdev) |
| return rx.FuncStructInfo([x], z) |
| |
| _check_derive( |
| bb, |
| func1(1), |
| [rx.TensorStructInfo([10, 11], "float32", vdev)], |
| rx.TensorStructInfo([12, 10], "float32", vdev), |
| ) |
| |
| # opaque derivation |
| fopaque0 = lambda: rx.FuncStructInfo.opaque_func() |
| fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) |
| _check_derive(bb, fopaque0(), [obj0, prim0], obj0) |
| _check_derive(bb, fopaque1(), [obj0, prim0], prim0) |
| |
| # recursive tuple derivation |
| def func_tuple0(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x0 = rx.TensorStructInfo([n, c], "float32") |
| x1 = rx.TensorStructInfo([n + c, m], "float32") |
| z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) |
| return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) |
| |
| _check_derive( |
| bb, |
| func_tuple0(2), |
| [ |
| rx.TupleStructInfo( |
| [ |
| rx.TensorStructInfo([n, 2], "float32"), |
| rx.TensorStructInfo([n + 2, 10], "float32"), |
| ] |
| ) |
| ], |
| rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), |
| ) |
| |
| def func_tuple1(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x0 = rx.TensorStructInfo([n, m], "float32") |
| x1 = rx.TensorStructInfo([n + c, c], "float32") |
| z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) |
| return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) |
| |
| # Still OK, to pass erased tensor into n+2, n is captured by other argument. |
| _check_derive( |
| bb, |
| func_tuple1(4), |
| [ |
| rx.TupleStructInfo( |
| [ |
| rx.TensorStructInfo([n, 4], "float32"), |
| rx.TensorStructInfo(ndim=2, dtype="float32"), |
| ] |
| ) |
| ], |
| rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]), |
| ) |
| |
| # tuple length mismatch is not causes an error |
| with pytest.raises(TVMError): |
| _check_derive( |
| bb, |
| func_tuple0(4), |
| [rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])], |
| rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), |
| ) |
| |
| # mixed shape types |
| def func_shape_mixed(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x0 = rx.ShapeStructInfo([n, m]) |
| f0 = func_tuple0(c) |
| z = rx.ShapeStructInfo([m + n, c]) |
| return rx.FuncStructInfo([x0, f0], z) |
| |
| _check_derive( |
| bb, |
| func_shape_mixed(3), |
| [ |
| rx.ShapeStructInfo([10, 20]), |
| # have to specify purity because an impure function cannot be passed |
| # where a pure one is expected |
| rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2), purity=True), |
| ], |
| rx.ShapeStructInfo([30, 3]), |
| ) |
| |
| |
| def _check_lca(lhs, rhs, target): |
| tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) |
| tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target) |
| |
| |
| def test_struct_info_lca(): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| obj0 = rx.ObjectStructInfo() |
| prim0 = rx.PrimStructInfo("int32") |
| prim1 = rx.PrimStructInfo("float32") |
| |
| vdevice0 = ir.VDevice("llvm") |
| vdevice1 = ir.VDevice("cuda", 0) |
| |
| shape0 = rx.ShapeStructInfo(ndim=-1) |
| shape1 = rx.ShapeStructInfo(ndim=2) |
| shape2 = rx.ShapeStructInfo(ndim=3) |
| shape3 = rx.ShapeStructInfo([1, 2, 3]) |
| shape4 = rx.ShapeStructInfo([1, n, 3]) |
| |
| tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") |
| tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") |
| tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") |
| tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") |
| tensor4 = rx.TensorStructInfo([n, m], "int32") |
| tensor5 = rx.TensorStructInfo([n, m, 1], "int32") |
| tensor6 = rx.TensorStructInfo([n, m, 2], "int32") |
| tensor7 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice0) |
| tensor8 = rx.TensorStructInfo(ndim=2, dtype="float32", vdevice=vdevice1) |
| tensor9 = rx.TensorStructInfo([n, m, 2], "int32", vdevice0) |
| tensor10 = rx.TensorStructInfo([n, m, 2], "int32", vdevice1) |
| |
| # obj |
| _check_lca(obj0, prim0, obj0) |
| _check_lca(obj0, prim1, obj0) |
| |
| # shape |
| _check_lca(shape0, tensor0, obj0) |
| _check_lca(shape0, shape1, shape0) |
| _check_lca(shape1, shape2, shape0) |
| _check_lca(shape1, shape3, shape0) |
| |
| _check_lca(shape2, shape3, shape2) |
| _check_lca(shape3, shape4, shape2) |
| _check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4) |
| |
| # tensor |
| _check_lca(tensor0, prim0, obj0) |
| _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None)) |
| _check_lca(tensor0, tensor2, tensor0) |
| _check_lca(tensor0, tensor4, tensor0) |
| _check_lca(tensor0, tensor4, tensor0) |
| _check_lca(tensor1, tensor3, tensor1) |
| _check_lca(tensor3, tensor7, tensor3) |
| _check_lca(tensor3, tensor8, tensor3) |
| _check_lca(tensor1, tensor8, tensor1) |
| _check_lca(tensor6, tensor9, tensor6) |
| _check_lca(tensor6, tensor10, tensor6) |
| |
| _check_lca(tensor2, tensor4, tensor2) |
| _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32")) |
| _check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32")) |
| _check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4) |
| |
| # tuple |
| t0 = rx.TupleStructInfo([obj0, tensor0]) |
| t1 = rx.TupleStructInfo([prim0, tensor4]) |
| t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) |
| t3 = rx.TupleStructInfo([tensor0, obj0]) |
| |
| _check_lca(t0, t1, t0) |
| _check_lca(t0, t2, obj0) |
| _check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0])) |
| |
| t5 = rx.TupleStructInfo([t0, t1]) |
| t6 = rx.TupleStructInfo([t1, t2]) |
| |
| _check_lca(t5, t6, rx.TupleStructInfo([t0, obj0])) |
| |
| t7 = rx.TupleStructInfo([]) |
| _check_lca(t7, rx.TupleStructInfo([]), t7) |
| |
| def fn_info_shape(c): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| x = rx.TensorStructInfo([c, n, m], "float32") |
| y = rx.TensorStructInfo([c, n, 1], "float32") |
| z = rx.TensorStructInfo([c, n], "float32") |
| return rx.FuncStructInfo([x, y], z) |
| |
| def fn_info_erased(): |
| x = rx.TensorStructInfo(ndim=3, dtype="float32") |
| y = rx.TensorStructInfo(ndim=3, dtype="float32") |
| z = rx.TensorStructInfo(ndim=2, dtype="float32") |
| return rx.FuncStructInfo([x, y], z) |
| |
| fopaque0 = lambda: rx.FuncStructInfo.opaque_func() |
| fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) |
| fopaque2 = lambda: rx.FuncStructInfo.opaque_func( |
| ret=rx.TensorStructInfo(ndim=2, dtype="float32") |
| ) |
| |
| _check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased()) |
| _check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2)) |
| |
| _check_lca(fopaque0(), fopaque1(), fopaque0()) |
| _check_lca(fopaque0(), fn_info_shape(1), fopaque0()) |
| _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) |
| |
| |
| def _generate_prim_test_cases(): |
| dtypes = [ |
| "bool", |
| "int8", |
| "uint8", |
| "int16", |
| "uint16", |
| "int32", |
| "uint32", |
| "int64", |
| "uint64", |
| "float16", |
| "float32", |
| "float64", |
| ] |
| |
| for dtype in dtypes: |
| # LCA of a PrimStructInfo with itself yields itself |
| yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype)) |
| |
| # The LCA of two values, each statically known to be the same |
| # value, is known to have that value. |
| yield ( |
| R.Prim(value=tir.const(0, dtype)), |
| R.Prim(value=tir.const(0, dtype)), |
| R.Prim(value=tir.const(0, dtype)), |
| ) |
| |
| # The LCA of two values, each of which is statically known to |
| # have a different value, no longer knows the contained value. |
| yield ( |
| R.Prim(value=tir.const(0, dtype)), |
| R.Prim(value=tir.const(1, dtype)), |
| R.Prim(dtype=dtype), |
| ) |
| |
| # LCA of a known variable with itself yields itself |
| var_N = tir.Var("N", dtype) |
| yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N)) |
| |
| # LCA of a known variable with a known static value is no |
| # longer known to have a specific value. |
| yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype)) |
| yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype)) |
| |
| var_M = tir.Var("M", dtype) |
| yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype)) |
| |
| for dtype_a in dtypes: |
| for dtype_b in dtypes: |
| if dtype_a != dtype_b: |
| # Unlike R.Tensor, R.Prim does not currently support a |
| # value with an unknown datatype. If the dtype |
| # differs between the two annotations, the next wider |
| # category is R.Object. |
| yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object) |
| |
| # Because the dtypes are different, even `R.Prim` containing |
| # the same value in different representations (e.g. |
| # `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`. |
| yield ( |
| R.Prim(value=tir.const(0, dtype_a)), |
| R.Prim(value=tir.const(0, dtype_b)), |
| R.Object, |
| ) |
| |
| # And the same is true for known variable values |
| var_N = tir.Var("N", dtype_a) |
| var_M = tir.Var("M", dtype_b) |
| yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object) |
| |
| |
| @pytest.mark.parametrize("test_case", list(_generate_prim_test_cases())) |
| def test_prim_struct_info_lca(test_case): |
| def _normalize_sinfo(sinfo): |
| if isinstance(sinfo, tvm.relax.StructInfo): |
| return sinfo |
| elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy): |
| return sinfo.as_struct_info() |
| elif callable(sinfo): |
| return sinfo() |
| else: |
| raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo") |
| |
| lhs, rhs, expected = map(_normalize_sinfo, test_case) |
| |
| lca = rx.analysis.struct_info_lca(lhs, rhs) |
| assert tvm.ir.structural_equal( |
| lca, expected |
| ), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}" |
| |
| |
| def _generate_tir_var_test_cases(): |
| n, m = tir.Var("n", "int64"), tir.Var("m", "int64") |
| shape0 = rx.ShapeStructInfo([1, n, 3]) |
| shape1 = rx.ShapeStructInfo([1, 2 * n, n, m]) |
| shape2 = rx.ShapeStructInfo([1, 2 * n, m]) |
| tensor0 = rx.TensorStructInfo([1, n, 3], "int32") |
| tensor1 = rx.TensorStructInfo([1, 2 * n, n, m], "int32") |
| tensor2 = rx.TensorStructInfo([1, 2 * n, m], "int32") |
| func = rx.FuncStructInfo( |
| [rx.TensorStructInfo([1, 2 * n, n, m], "int32")], rx.TensorStructInfo([1, n, 3], "int32") |
| ) |
| |
| yield shape0, [n], [n] |
| yield shape1, [n, m], [n, m] |
| yield shape2, [m], [n, m] |
| yield tensor0, [n], [n] |
| yield tensor1, [n, m], [n, m] |
| yield tensor2, [m], [n, m] |
| yield func, [n, m], [n, m] |
| |
| |
| tir_var_test_case = tvm.testing.parameter(*_generate_tir_var_test_cases()) |
| |
| |
| def test_tir_vars_in_struct_info(tir_var_test_case): |
| sinfo, _vars_definable, vars_used = tir_var_test_case |
| tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(sinfo), vars_used) |
| |
| |
| def test_definable_tir_vars_in_struct_info(tir_var_test_case): |
| sinfo, vars_definable, _vars_used = tir_var_test_case |
| tvm.ir.assert_structural_equal( |
| rx.analysis.definable_tir_vars_in_struct_info(sinfo), vars_definable |
| ) |
| |
| |
| def test_collect_symbolic_var_from_tensor_shape(): |
| n, m, k, q, p = ( |
| tir.Var("n", "int64"), |
| tir.Var("m", "int64"), |
| tir.Var("k", "int64"), |
| tir.Var("q", "int64"), |
| tir.Var("p", "int64"), |
| ) |
| bb = rx.BlockBuilder() |
| x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32")) |
| with bb.function("main", [x]): |
| v0 = bb.match_cast(x, rx.TensorStructInfo([m, k], "float32")) |
| v1 = bb.emit(rx.call_dps_packed("test", x, rx.TensorStructInfo([p, q], "float32"))) |
| bb.emit_func_output(rx.const(1)) |
| func = bb.get()["main"] |
| |
| defined_vars = set(rx.analysis.defined_symbolic_vars(func)) |
| free_vars = set(rx.analysis.free_symbolic_vars(func)) |
| assert defined_vars == {m, k} |
| assert free_vars == {n, p, q} |
| |
| |
| param_type = tvm.testing.parameter("shape_expr", "prim_value") |
| param_order = tvm.testing.parameter("definition_first", "usage_first") |
| |
| |
| def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order): |
| tir_n = tir.Var("n", "int64") |
| tir_m = tir.Var("m", "int64") |
| |
| bb = rx.BlockBuilder() |
| arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m])) |
| |
| if param_type == "shape_expr": |
| extra_params = [ |
| rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])), |
| ] |
| elif param_type == "prim_value": |
| extra_params = [ |
| rx.Var("n", rx.PrimStructInfo(value=tir_n)), |
| rx.Var("m", rx.PrimStructInfo(value=tir_m)), |
| ] |
| else: |
| raise ValueError(f"Unknown param_type: {param_type}") |
| |
| if param_order == "definition_first": |
| params = [*extra_params, arg] |
| elif param_order == "usage_first": |
| params = [arg, *extra_params] |
| else: |
| raise ValueError(f"Unknown param_order: {param_order}") |
| |
| with bb.function("main", params=params): |
| out = rx.op.reshape(arg, [tir_n, tir_m]) |
| bb.emit_func_output(out) |
| func = bb.get()["main"] |
| |
| defined_vars = set(rx.analysis.defined_symbolic_vars(func)) |
| free_vars = set(rx.analysis.free_symbolic_vars(func)) |
| assert defined_vars == {tir_n, tir_m} |
| assert free_vars == set() |
| |
| |
| def test_collect_nonnegative_expressions(): |
| @R.function |
| def func( |
| A: R.Tensor([1024, "M", "N-2"]), |
| B: R.Tensor([128, "N", "M+2"]), |
| C: R.Shape(["M", "N"]), |
| D: R.Prim(value="N"), |
| ): |
| return R.tuple() |
| |
| M, N = list(func.params[2].struct_info.values) |
| |
| # Expressions are de-duplicated, in order of their first appearance |
| tvm.ir.assert_structural_equal( |
| rx.analysis.collect_non_negative_expressions(func.struct_info), |
| [M, N - 2, N, M + 2], |
| ) |
| |
| # Tensor shapes can imply that their shapes are non-negative |
| tvm.ir.assert_structural_equal( |
| rx.analysis.collect_non_negative_expressions(func.params[0].struct_info), |
| [M, N - 2], |
| ) |
| tvm.ir.assert_structural_equal( |
| rx.analysis.collect_non_negative_expressions(func.params[1].struct_info), |
| [N, M + 2], |
| ) |
| |
| # ShapeExpr values can imply that their contents are non-negative |
| tvm.ir.assert_structural_equal( |
| rx.analysis.collect_non_negative_expressions(func.params[2].struct_info), |
| [M, N], |
| ) |
| |
| # PrimValue instances may contain negative values, and do not |
| # imply that their contents are non-negative. |
| tvm.ir.assert_structural_equal( |
| rx.analysis.collect_non_negative_expressions(func.params[3].struct_info), |
| [], |
| ) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |