| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import pytest |
| import tvm |
| import tvm.testing |
| from tvm import relax, tir |
| from tvm import TVMError |
| from tvm.ir import Op, VDevice |
| from tvm.script import relax as R, tir as T |
| |
| |
| def test_op_correctness(): |
| x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| assert relax.op.broadcast_to(x, (3, 3, 4, 5)).op == Op.get("relax.broadcast_to") |
| assert relax.op.concat([x]).op == Op.get("relax.concat") |
| assert relax.op.expand_dims(x, axis=[]).op == Op.get("relax.expand_dims") |
| assert relax.op.flatten(x).op == Op.get("relax.flatten") |
| assert relax.op.permute_dims(x).op == Op.get("relax.permute_dims") |
| assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape") |
| assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split") |
| assert relax.op.tile(x, (2, 2, 2)).op == Op.get("relax.tile") |
| assert relax.op.repeat(x, 2, 0).op == Op.get("relax.repeat") |
| assert relax.op.squeeze(x).op == Op.get("relax.squeeze") |
| assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get( |
| "relax.layout_transform" |
| ) |
| assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to") |
| y = relax.Var("x", R.Tensor((4, 5), "float32")) |
| assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like") |
| assert relax.op.cumsum(x, axis=1, dtype="int32").op == Op.get("relax.cumsum") |
| assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum") |
| assert relax.op.flip(x, axis=1).op == Op.get("relax.flip") |
| assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements") |
| assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd") |
| |
| |
| def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): |
| ret = bb.normalize(call) |
| tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) |
| |
| |
| def test_reshape_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) |
| x4 = relax.Var("x", R.Tensor(ndim=4)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0)) |
| s0 = relax.Var("s", R.Shape((3, 8, 5))) |
| s1 = relax.Var("s", R.Shape(ndim=3)) |
| s2 = relax.Var("s", R.Shape()) |
| s3 = relax.ShapeExpr((3, 8, 5)) |
| |
| _check_inference( |
| bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x6, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32", vdev0) |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32")) |
| _check_inference( |
| bb, relax.op.reshape(x0, relax.ShapeExpr([-1])), relax.TensorStructInfo((120,), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x3, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x3, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x4, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") |
| ) |
| # Remove Var from StructInfo when we can |
| _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo((3, 8, 5), "float32")) |
| _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo((3, 8, 5), "float32")) |
| _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo((3, 8, 5), "float32")) |
| _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) |
| _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) |
| _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) |
| _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.reshape(x3, s1), relax.TensorStructInfo(s1, dtype="")) |
| _check_inference(bb, relax.op.reshape(x4, s1), relax.TensorStructInfo(s1, dtype="")) |
| _check_inference(bb, relax.op.reshape(x5, s1), relax.TensorStructInfo(s1, dtype="")) |
| _check_inference(bb, relax.op.reshape(x0, s2), relax.TensorStructInfo(s2, "float32")) |
| _check_inference(bb, relax.op.reshape(x1, s2), relax.TensorStructInfo(s2, "float32")) |
| _check_inference(bb, relax.op.reshape(x2, s2), relax.TensorStructInfo(s2, "float32")) |
| _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype="")) |
| _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype="")) |
| _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype="")) |
| _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3, "float32")) |
| _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3, "float32")) |
| _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3, "float32")) |
| _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3, dtype="")) |
| _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3, dtype="")) |
| _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3, dtype="")) |
| |
| |
| def test_reshape_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| d = tir.Var("d", "int64") |
| x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) |
| s0 = relax.Var("s", R.Shape((c, a, d, b))) |
| s1 = relax.Var("s", R.Shape()) |
| s2 = relax.ShapeExpr((c, a, d, b)) |
| |
| _check_inference( |
| bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.reshape(x, (d, c, b, -1)), |
| relax.TensorStructInfo((d, c, b, a), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.reshape(x, (1, -1, 1)), |
| relax.TensorStructInfo((1, a * b * c * d, 1), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.reshape(x, (2, -1, a)), |
| relax.TensorStructInfo((2, tir.floordiv(b * c * d, 2), a), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.reshape(x, (c, -1, d, b)), |
| relax.TensorStructInfo((c, a, d, b), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.reshape(x, (c, a * d, b)), |
| relax.TensorStructInfo((c, a * d, b), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.reshape(x, (c, a * b * d, -1)), |
| relax.TensorStructInfo((c, a * b * d, 1), "float32"), |
| ) |
| # Remove Var from StructInfo when we can |
| _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, a, d, b), "float32")) |
| _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32")) |
| |
| |
| def test_reshape_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| ns0 = relax.Var("ns", relax.ShapeStructInfo((3, 8, 5))) |
| ns1 = relax.Var("ns", relax.ShapeStructInfo()) |
| |
| _check_inference( |
| bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| # Remove Var from StructInfo when we can |
| _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) |
| _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32")) |
| _check_inference( |
| bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| # Remove Var from StructInfo when we can |
| _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) |
| _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorStructInfo(ns1, "float32")) |
| _check_inference( |
| bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") |
| ) |
| # Remove Var from StructInfo when we can |
| _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) |
| _check_inference(bb, relax.op.reshape(x2, ns1), relax.TensorStructInfo(ns1, "float32")) |
| |
| |
| def test_reshape_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) |
| |
| _check_inference(bb, relax.op.reshape(x0, (120,)), relax.TensorStructInfo((120,), "float16")) |
| _check_inference(bb, relax.op.reshape(x1, (120,)), relax.TensorStructInfo((120,), "int8")) |
| |
| |
| def test_reshape_infer_struct_info_unequal_shape_prod(): |
| bb = relax.BlockBuilder() |
| s = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) |
| x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) |
| ns = relax.Var("ns", relax.ShapeStructInfo((4, 4, 1, 5))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x0, (4, 4, 1, 5))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x1, (4, 4, 1, 5))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x0, (4, 4, -1, 5))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x1, (4, 4, -1, 5))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x0, ns)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x1, ns)) |
| |
| |
| def test_reshape_infer_struct_info_inference_not_deducible(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) |
| s1 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| x1 = relax.Var("x", R.Tensor("float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x0, (2, 3, -1))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x1, (2, 3, -1))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x2, (2, 3, -1))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x3, (2, 3, -1))) |
| |
| |
| def test_reshape_new_shape_not_tuple(): |
| m = tir.Var("m", "int64") |
| x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| |
| with pytest.raises(TypeError): |
| relax.op.reshape(x, 120) |
| with pytest.raises(TypeError): |
| relax.op.reshape(x, m) |
| |
| |
| def test_reshape_infer_struct_info_new_shape_not_integer(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x, (2.0, 3, 4, 5))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x, (2, 3, -1.0))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x, (2, 3, 4.0, -1))) |
| |
| |
| def test_reshape_infer_struct_info_multiple_dim_inference(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x, (2, -1, -1, 5))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x, (-1, -1, -1, -1))) |
| |
| |
| def test_reshape_infer_struct_info_non_positive_new_shape(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5))) |
| |
| |
| def test_reshape_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) |
| x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| ns = relax.Var("ns", relax.TensorStructInfo((120,), "float32")) |
| pv = relax.Var("pv", relax.PrimStructInfo("int64")) |
| |
| with pytest.raises((TVMError, TypeError)): |
| bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5))) |
| with pytest.raises((TVMError, TypeError)): |
| bb.normalize(relax.op.reshape(x1, (2, 3, 4, 5))) |
| with pytest.raises((TVMError, TypeError)): |
| bb.normalize(relax.op.reshape(x2, ns)) |
| with pytest.raises((TVMError, TypeError)): |
| bb.normalize(relax.op.reshape(x2, [pv])) |
| |
| |
| def test_permute_dims_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((1, 2, 3, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=4)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((1,), "float32")) |
| x7 = relax.Var("x", R.Tensor((), "float32")) |
| x8 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.permute_dims(x8, [2, 3, 1, 0]), |
| relax.TensorStructInfo((3, 4, 2, 1), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.permute_dims(x0, [-2, -3, 3, -4]), |
| relax.TensorStructInfo((3, 2, 4, 1), "float32"), |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x3, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x3, axes=None), relax.TensorStructInfo((4, 3, 2, 1), dtype="") |
| ) |
| _check_inference( |
| bb, |
| relax.op.permute_dims(x3, [-2, -3, 3, -4]), |
| relax.TensorStructInfo((3, 2, 4, 1), dtype=""), |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x4, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x4, axes=None), relax.TensorStructInfo(dtype="", ndim=4) |
| ) |
| _check_inference(bb, relax.op.permute_dims(x5, axes=None), relax.TensorStructInfo(dtype="")) |
| _check_inference( |
| bb, relax.op.permute_dims(x6, axes=None), relax.TensorStructInfo((1,), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x7, axes=None), relax.TensorStructInfo((), "float32") |
| ) |
| |
| |
| def test_permute_dims_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| d = tir.Var("d", "int64") |
| x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) |
| |
| _check_inference( |
| bb, relax.op.permute_dims(x, [2, 3, 1, 0]), relax.TensorStructInfo((c, d, b, a), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x, axes=None), relax.TensorStructInfo((d, c, b, a), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.permute_dims(x, [-2, -3, 3, -4]), |
| relax.TensorStructInfo((c, b, d, a), "float32"), |
| ) |
| |
| |
| def test_permute_dims_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| _check_inference( |
| bb, relax.op.permute_dims(x0, [0, 1, 2, 3]), relax.TensorStructInfo(s0, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x0, [-4, -3, -2, -1]), relax.TensorStructInfo(s0, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x0, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x1, [0, 1, 2, 3]), relax.TensorStructInfo(s1, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x1, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") |
| ) |
| |
| |
| def test_permute_dims_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int8")) |
| x2 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int32")) |
| |
| _check_inference( |
| bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float16") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int8") |
| ) |
| _check_inference( |
| bb, relax.op.permute_dims(x2, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int32") |
| ) |
| |
| |
| def test_permute_dims_infer_struct_info_unknown_ndim_with_axes(): |
| bb = relax.BlockBuilder() |
| s = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor("float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [2, 3, 1, 0])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [2, 3, 1, 0])) |
| |
| |
| def test_permute_dims_infer_struct_info_wrong_number_axes(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) |
| x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [0, 2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [1, 2, 4, 0, 3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [0, 2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [1, 2, 4, 0, 3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x2, [0, 2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x2, [1, 2, 4, 0, 3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x3, [0, 2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x3, [1, 2, 4, 0, 3])) |
| |
| |
| def test_permute_dims_infer_struct_info_axis_out_of_range(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [0, 3, 4, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [0, -5, 1, 3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [0, 3, 4, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [0, -5, 1, 3])) |
| |
| |
| def test_permute_dims_infer_struct_info_repetitive_axes(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [0, 2, 2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0, [0, 2, -2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [0, 2, 2, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1, [0, 2, -2, 1])) |
| |
| |
| def test_permute_dims_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((1, 2, 3, 4))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((1, 2, 3, 4), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.permute_dims(x1)) |
| |
| |
| def test_expand_dims_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.expand_dims(x6, [1, 3]), |
| relax.TensorStructInfo((2, 1, 3, 1, 4), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, |
| relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]), |
| relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), "float32"), |
| ) |
| _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo((2, 3, 4), "float32")) |
| _check_inference( |
| bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) |
| ) |
| _check_inference( |
| bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(dtype="float32")) |
| _check_inference( |
| bb, relax.op.expand_dims(x3, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), dtype="") |
| ) |
| _check_inference( |
| bb, |
| relax.op.expand_dims(x3, [-1, 1, -6, 3, 5]), |
| relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), dtype=""), |
| ) |
| _check_inference(bb, relax.op.expand_dims(x3, []), relax.TensorStructInfo((2, 3, 4), dtype="")) |
| _check_inference(bb, relax.op.expand_dims(x4, [1, 3]), relax.TensorStructInfo(dtype="", ndim=5)) |
| _check_inference(bb, relax.op.expand_dims(x4, []), relax.TensorStructInfo(dtype="", ndim=3)) |
| _check_inference(bb, relax.op.expand_dims(x5, [1, 3]), relax.TensorStructInfo(dtype="")) |
| _check_inference(bb, relax.op.expand_dims(x5, []), relax.TensorStructInfo(dtype="")) |
| |
| |
| def test_expand_dims_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x = relax.Var("x", R.Tensor((a, 4, b), "float32")) |
| |
| _check_inference( |
| bb, relax.op.expand_dims(x, [1, 3]), relax.TensorStructInfo((a, 1, 4, 1, b), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.expand_dims(x, [-1, 1, -6, 3, 5]), |
| relax.TensorStructInfo((a, 1, 1, 1, 4, 1, b, 1), "float32"), |
| ) |
| _check_inference(bb, relax.op.expand_dims(x, []), relax.TensorStructInfo((a, 4, b), "float32")) |
| |
| |
| def test_expand_dims_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| _check_inference( |
| bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) |
| ) |
| _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo(s0, "float32")) |
| _check_inference( |
| bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) |
| ) |
| _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(s2, "float32")) |
| |
| |
| def test_expand_dims_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) |
| x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) |
| |
| _check_inference( |
| bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float16") |
| ) |
| _check_inference( |
| bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int8") |
| ) |
| _check_inference( |
| bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int32") |
| ) |
| |
| |
| def test_expand_dims_infer_struct_info_axis_out_of_range(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0)) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x0, [1, 5])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x0, [-6, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x1, [1, 5])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x1, [-6, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x2, [1, 5])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x2, [-6, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x3, [1, 5])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x3, [-6, 1])) |
| |
| |
| def test_expand_dims_infer_struct_info_repetitive_axes(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0)) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x0, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x0, [1, -4])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x1, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x1, [1, -4])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x2, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x2, [1, -4])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x3, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x3, [1, -4])) |
| |
| |
| def test_expand_dims_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x0, axis=[])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.expand_dims(x1, axis=[])) |
| |
| |
| def test_layout_transform_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) |
| x1 = relax.Var("x", R.Tensor((10, 20, 30), "float32", vdev0)) |
| |
| transpose_transform = lambda a, b, c: (a, c, b) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x, index_map=transpose_transform), |
| relax.TensorStructInfo((10, 30, 20), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x1, index_map=transpose_transform), |
| relax.TensorStructInfo((10, 30, 20), "float32", vdev0), |
| ) |
| |
| tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x, index_map=tiling_transform), |
| relax.TensorStructInfo((10, 10, 30, 2), "float32"), |
| ) |
| |
| implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2), |
| relax.TensorStructInfo((10, 30, 7, 3), "float32"), |
| ) |
| |
| flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x, index_map=flatten_transform), |
| relax.TensorStructInfo((6000,), "float32"), |
| ) |
| |
| |
| def test_layout_transform_infer_struct_info_mismatch_dtype(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((10, 20, 30), "int32")) |
| |
| transpose_transform = lambda a, b, c: (a, c, b) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2)) |
| |
| |
| def test_layout_transform_infer_struct_info_unknown_shape(): |
| bb = relax.BlockBuilder() |
| tiling_transform = lambda a, b: (a, b // 2, b % 2) |
| |
| x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2)) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| |
| x_unknown_rank_dtype = relax.Var("x", R.Tensor()) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform), |
| relax.TensorStructInfo(dtype="", ndim=3), |
| ) |
| |
| |
| def test_layout_transform_infer_struct_info_symbolic_shape(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x0 = relax.Var("x", R.Tensor((a, b), "float32")) |
| |
| tiling_transform = lambda a, b: (a, b // 3, b % 3) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x0, index_map=tiling_transform), |
| relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), |
| ) |
| |
| |
| def test_layout_transform_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| |
| s = relax.Var("s", relax.ShapeStructInfo((30, 20))) |
| x = relax.Var("x", relax.TensorStructInfo(s, "float32")) |
| tiling_padding_transform = lambda a, b: (a, b // 3, b % 3) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x, index_map=tiling_padding_transform), |
| relax.TensorStructInfo((30, 7, 3), "float32"), |
| ) |
| |
| s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2)) |
| x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32")) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| |
| s_unknown_rank = relax.Var("s", relax.ShapeStructInfo()) |
| x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32")) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b))) |
| x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32")) |
| _check_inference( |
| bb, |
| relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform), |
| relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), |
| ) |
| |
| |
| def test_layout_transform_infer_struct_info_invalid_index_map(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a))) |
| |
| |
| def test_squeeze_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=6)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=6)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.squeeze(x6, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32", vdev0) |
| ) |
| _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32")) |
| _check_inference( |
| bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x2, [1, 4]), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) |
| _check_inference( |
| bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), dtype="") |
| ) |
| _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) |
| _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="", ndim=4)) |
| _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="")) |
| _check_inference(bb, relax.op.squeeze(x5, [1, 4]), relax.TensorStructInfo(dtype="")) |
| _check_inference(bb, relax.op.squeeze(x5), relax.TensorStructInfo(dtype="")) |
| |
| |
| def test_squeeze_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x0 = relax.Var("x", R.Tensor((a, 1, b), "float32")) |
| x1 = relax.Var("x", R.Tensor((a, 1, b))) |
| |
| _check_inference(bb, relax.op.squeeze(x0, [1]), relax.TensorStructInfo((a, b), "float32")) |
| _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x1, [1]), relax.TensorStructInfo((a, b), dtype="")) |
| _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="")) |
| |
| |
| def test_squeeze_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) |
| s2 = relax.Var("s", relax.ShapeStructInfo((a, 1, b))) |
| s3 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) |
| s4 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) |
| |
| _check_inference( |
| bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference(bb, relax.op.squeeze(x0, []), relax.TensorStructInfo(s0, "float32")) |
| _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x1, []), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(s1, dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x2, [1]), relax.TensorStructInfo(dtype="float32", ndim=2)) |
| _check_inference(bb, relax.op.squeeze(x2, []), relax.TensorStructInfo(s2, "float32")) |
| _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) |
| _check_inference( |
| bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) |
| ) |
| _check_inference(bb, relax.op.squeeze(x3, []), relax.TensorStructInfo(s3, "float32")) |
| _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.squeeze(x4, []), relax.TensorStructInfo(s4, "float32")) |
| _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="float32")) |
| |
| |
| def test_squeeze_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int8")) |
| x2 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int32")) |
| |
| _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float16")) |
| _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo((2, 3, 4), "int8")) |
| _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo((2, 3, 4), "int32")) |
| |
| |
| def test_squeeze_infer_struct_info_axis_out_of_range(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) |
| x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=6)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x0, [6])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x0, [-7])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x1, [6])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x1, [-7])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x2, [6])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x2, [-7])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x3, [6])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x3, [-7])) |
| |
| |
| def test_squeeze_infer_struct_info_repetitive_axes(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) |
| x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=6)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x0, [3, -3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x0, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x1, [3, -3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x1, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x2, [3, -3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x2, [1, 1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x3, [3, -3])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x3, [1, 1])) |
| |
| |
| def test_squeeze_infer_struct_info_axis_length_not_one(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo((a, 3, 4))) |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor((a, 3, 4), "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x0, [0])) |
| _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32")) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x2, [0])) |
| _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2)) |
| |
| |
| def test_squeeze_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.squeeze(x1)) |
| |
| |
| def test_flatten_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| x1 = relax.Var("x", R.Tensor((3,), "float32")) |
| x2 = relax.Var("x", R.Tensor((), "float32")) |
| x3 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x4 = relax.Var("x", R.Tensor("float32", ndim=1)) |
| x5 = relax.Var("x", R.Tensor("float32", ndim=0)) |
| x6 = relax.Var("x", R.Tensor("float32")) |
| x7 = relax.Var("x", R.Tensor((3, 4, 5))) |
| x8 = relax.Var("x", R.Tensor((3,))) |
| x9 = relax.Var("x", R.Tensor(())) |
| x10 = relax.Var("x", R.Tensor(ndim=3)) |
| x11 = relax.Var("x", R.Tensor(ndim=1)) |
| x12 = relax.Var("x", R.Tensor(ndim=0)) |
| x13 = relax.Var("x", R.Tensor()) |
| x14 = relax.Var("x", R.Tensor((3, 4, 5), "float32", vdev0)) |
| |
| _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32")) |
| _check_inference(bb, relax.op.flatten(x14), relax.TensorStructInfo((60,), "float32", vdev0)) |
| _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32")) |
| _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) |
| _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) |
| _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x7), relax.TensorStructInfo((60,), dtype="")) |
| _check_inference(bb, relax.op.flatten(x8), relax.TensorStructInfo((3,), dtype="")) |
| _check_inference(bb, relax.op.flatten(x9), relax.TensorStructInfo((1,), dtype="")) |
| _check_inference(bb, relax.op.flatten(x10), relax.TensorStructInfo(dtype="", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x11), relax.TensorStructInfo(dtype="", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x12), relax.TensorStructInfo((1,), dtype="")) |
| _check_inference(bb, relax.op.flatten(x13), relax.TensorStructInfo(dtype="", ndim=1)) |
| |
| |
| def test_flatten_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x0 = relax.Var("x", R.Tensor((a, b), "float32")) |
| x1 = relax.Var("x", R.Tensor((a, b))) |
| |
| _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((a * b,), "float32")) |
| _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((a * b,), dtype="")) |
| |
| |
| def test_flatten_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) |
| s1 = relax.Var("s", relax.ShapeStructInfo((3,))) |
| s2 = relax.Var("s", relax.ShapeStructInfo(())) |
| s3 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) |
| s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) |
| s6 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) |
| x5 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) |
| x6 = relax.Var("x", relax.TensorStructInfo(s6, "float32")) |
| |
| _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo(s1, "float32")) |
| _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) |
| _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(s4, "float32")) |
| _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) |
| _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| |
| |
| def test_flatten_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3, 4, 5), "float16")) |
| x1 = relax.Var("x", R.Tensor((3, 4, 5), "int8")) |
| x2 = relax.Var("x", R.Tensor((3, 4, 5), "int32")) |
| |
| _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float16")) |
| _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((60,), "int8")) |
| _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((60,), "int32")) |
| |
| |
| def test_flatten_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((3, 4, 5))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((3, 4, 5), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.flatten(x0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.flatten(x1)) |
| |
| |
| def test_flatten_wrong_input_number(): |
| x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) |
| |
| with pytest.raises(TypeError): |
| relax.op.flatten(x, y) |
| |
| |
| def test_concat_infer_struct_info_with_axis(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) |
| y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) |
| y1 = relax.Var("y", R.Tensor("float32", ndim=3)) |
| y2 = relax.Var("y", R.Tensor("float32")) |
| y3 = relax.Var("y", R.Tensor((2, 4, 4))) |
| y4 = relax.Var("y", R.Tensor(ndim=3)) |
| y5 = relax.Var("y", R.Tensor()) |
| y6 = relax.Var("y", R.Tensor((2, 4, 4), "float32", vdev0)) |
| z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) |
| z1 = relax.Var("z", R.Tensor("float32", ndim=3)) |
| z2 = relax.Var("z", R.Tensor("float32")) |
| z3 = relax.Var("z", R.Tensor((2, 5, 4))) |
| z4 = relax.Var("z", R.Tensor(ndim=3)) |
| z5 = relax.Var("z", R.Tensor()) |
| z6 = relax.Var("z", R.Tensor((2, 5, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat([x6, y6, z6], axis=1), |
| relax.TensorStructInfo((2, 12, 4), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat([x6, y0, z0], axis=1), |
| relax.TensorStructInfo((2, 12, 4), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x5, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x5, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y2, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y2, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x5, y5, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y2, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y1, z1], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y2, z2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=-1) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y2, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x4, y4, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x5, y5, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=-1) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y3, z3], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y3, z3], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x4, y3, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x5, y5, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x4, y4, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x5, y5, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) |
| ) |
| _check_inference(bb, relax.op.concat([x5, y5, z5], axis=1), relax.TensorStructInfo(dtype="")) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), |
| relax.TensorStructInfo((2, 12, 4), "float32"), |
| ) |
| |
| |
| def test_concat_infer_struct_info_with_axis_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a0 = tir.Var("a0", "int64") |
| a1 = tir.Var("a1", "int64") |
| b0 = tir.Var("b0", "int64") |
| b1 = tir.Var("b1", "int64") |
| b2 = tir.Var("b2", "int64") |
| c = tir.Var("c", "int64") |
| x0 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) |
| x1 = relax.Var("x", R.Tensor((a1, b0, c), "float32")) |
| x2 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) |
| y = relax.Var("y", R.Tensor((a0, b1, c), "float32")) |
| z = relax.Var("z", R.Tensor((a0, b2, c), "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.concat([x0, y, z], axis=1), |
| relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat([x0, y, z], axis=-2), |
| relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y, z], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, y, z]), axis=1), |
| relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, x2]), axis=1), |
| relax.TensorStructInfo((a0, b0 * 2, c), "float32"), |
| ) |
| |
| |
| def test_concat_infer_struct_info_with_axis_shape_var(): |
| bb = relax.BlockBuilder() |
| a0 = tir.Var("a0", "int64") |
| a1 = tir.Var("a1", "int64") |
| b0 = tir.Var("b0", "int64") |
| b1 = tir.Var("b1", "int64") |
| b2 = tir.Var("b2", "int64") |
| c = tir.Var("c", "int64") |
| sx0 = relax.Var("sx", relax.ShapeStructInfo((2, 3, 4))) |
| sx1 = relax.Var("sx", relax.ShapeStructInfo((a0, b0, c))) |
| sx2 = relax.Var("sx", relax.ShapeStructInfo((a1, b0, c))) |
| sx3 = relax.Var("sx", relax.ShapeStructInfo(ndim=3)) |
| sx4 = relax.Var("sx", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(sx3, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(sx4, "float32")) |
| y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) |
| y1 = relax.Var("y", R.Tensor((a0, b1, c), "float32")) |
| z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) |
| z1 = relax.Var("z", R.Tensor((a0, b2, c), "float32")) |
| |
| _check_inference( |
| bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| |
| |
| def test_concat_infer_struct_info_without_axis(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3,), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=1)) |
| x2 = relax.Var("x", R.Tensor((3,))) |
| x3 = relax.Var("x", R.Tensor(ndim=1)) |
| y0 = relax.Var("y", R.Tensor((4,), "float32")) |
| y1 = relax.Var("y", R.Tensor("float32", ndim=1)) |
| z0 = relax.Var("z", R.Tensor((5,), "float32")) |
| z1 = relax.Var("z", R.Tensor("float32", ndim=1)) |
| |
| _check_inference( |
| bb, relax.op.concat([x0, y0, z0], axis=None), relax.TensorStructInfo((12,), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat([x1, y0, z0], axis=None), |
| relax.TensorStructInfo(dtype="float32", ndim=1), |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y0, z0], axis=None), relax.TensorStructInfo((12,), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3, y0, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat([x1, y1, z0], axis=None), |
| relax.TensorStructInfo(dtype="float32", ndim=1), |
| ) |
| _check_inference( |
| bb, relax.op.concat([x2, y1, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat([x1, y1, z1], axis=None), |
| relax.TensorStructInfo(dtype="float32", ndim=1), |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, y0, z0]), axis=None), |
| relax.TensorStructInfo((12,), "float32"), |
| ) |
| |
| |
| def test_concat_infer_struct_info_without_axis_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a0 = tir.Var("a0", "int64") |
| a1 = tir.Var("a1", "int64") |
| x0 = relax.Var("x", R.Tensor((a0,), "float32")) |
| x1 = relax.Var("x", R.Tensor((a0,), "")) |
| y0 = relax.Var("y", R.Tensor((a1,), "float32")) |
| y1 = relax.Var("y", R.Tensor((a1,), "")) |
| |
| _check_inference( |
| bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((a0 + a1,), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x0, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, y0]), axis=None), |
| relax.TensorStructInfo((a0 + a1,), "float32"), |
| ) |
| |
| |
| def test_concat_infer_struct_info_without_axis_shape_var(): |
| bb = relax.BlockBuilder() |
| sx0 = relax.Var("sx", relax.ShapeStructInfo((3,))) |
| sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=1)) |
| sy0 = relax.Var("sy", relax.ShapeStructInfo((4,))) |
| x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) |
| y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) |
| |
| _check_inference( |
| bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) |
| ) |
| _check_inference( |
| bb, |
| relax.op.concat(relax.Tuple([x0, y0]), axis=None), |
| relax.TensorStructInfo(dtype="float32", ndim=1), |
| ) |
| |
| |
| def test_concat_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3,), "float16")) |
| y0 = relax.Var("y", R.Tensor((4,), "float16")) |
| x1 = relax.Var("x", R.Tensor((3,), "int8")) |
| y1 = relax.Var("y", R.Tensor((4,), "int8")) |
| x2 = relax.Var("x", R.Tensor((3,), "int32")) |
| y2 = relax.Var("y", R.Tensor((4,), "int32")) |
| |
| _check_inference( |
| bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((7,), "float16") |
| ) |
| _check_inference(bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((7,), "int8")) |
| _check_inference( |
| bb, relax.op.concat([x2, y2], axis=None), relax.TensorStructInfo((7,), "int32") |
| ) |
| |
| |
| def test_concat_infer_struct_info_tuple_var(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a0", "int64") |
| b0 = tir.Var("b0", "int64") |
| b1 = tir.Var("b1", "int64") |
| t0 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1), "float32")] |
| ), |
| ) |
| t1 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((a, b0), "float32"), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ] |
| ), |
| ) |
| t2 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ] |
| ), |
| ) |
| t3 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] |
| ), |
| ) |
| t4 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1))] |
| ), |
| ) |
| t5 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo((a, b0), dtype=""), relax.TensorStructInfo((a, b1), dtype="")] |
| ), |
| ) |
| t6 = relax.Var( |
| "t", |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo(dtype="", ndim=2), relax.TensorStructInfo(dtype="")] |
| ), |
| ) |
| t7 = relax.Var( |
| "t", |
| relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), |
| ) |
| |
| _check_inference( |
| bb, relax.op.concat(t0, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.concat(t1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) |
| ) |
| _check_inference( |
| bb, relax.op.concat(t2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) |
| ) |
| _check_inference(bb, relax.op.concat(t3, axis=1), relax.TensorStructInfo(dtype="float32")) |
| _check_inference( |
| bb, relax.op.concat(t4, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.concat(t5, axis=1), relax.TensorStructInfo((a, b0 + b1), dtype="") |
| ) |
| _check_inference(bb, relax.op.concat(t6, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) |
| _check_inference(bb, relax.op.concat(t7, axis=1), relax.TensorStructInfo(dtype="")) |
| |
| |
| def test_concat_infer_struct_info_single_input_tensor(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| s0 = relax.Var("s", relax.ShapeStructInfo((3, a))) |
| s1 = relax.Var("s", relax.ShapeStructInfo((a,))) |
| s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) |
| s4 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor((3, a), "float32")) |
| x1 = relax.Var("x", R.Tensor((a,), "float32")) |
| x2 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x3 = relax.Var("x", R.Tensor("float32", ndim=1)) |
| x4 = relax.Var("x", R.Tensor("float32")) |
| x5 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x6 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x7 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| x8 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) |
| x9 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) |
| |
| _check_inference(bb, relax.op.concat([x0], axis=1), relax.TensorStructInfo((3, a), "float32")) |
| _check_inference(bb, relax.op.concat([x1], axis=0), relax.TensorStructInfo((a,), "float32")) |
| _check_inference(bb, relax.op.concat([x1], axis=None), relax.TensorStructInfo((a,), "float32")) |
| _check_inference( |
| bb, relax.op.concat([x2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3], axis=0), relax.TensorStructInfo(dtype="float32", ndim=1) |
| ) |
| _check_inference( |
| bb, relax.op.concat([x3], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) |
| ) |
| _check_inference(bb, relax.op.concat([x4], axis=1), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.concat([x5], axis=1), relax.TensorStructInfo(s0, dtype="float32")) |
| _check_inference(bb, relax.op.concat([x6], axis=0), relax.TensorStructInfo(s1, dtype="float32")) |
| _check_inference( |
| bb, relax.op.concat([x6], axis=None), relax.TensorStructInfo(s1, dtype="float32") |
| ) |
| _check_inference(bb, relax.op.concat([x7], axis=1), relax.TensorStructInfo(s2, dtype="float32")) |
| _check_inference(bb, relax.op.concat([x8], axis=0), relax.TensorStructInfo(s3, dtype="float32")) |
| _check_inference( |
| bb, relax.op.concat([x8], axis=None), relax.TensorStructInfo(s3, dtype="float32") |
| ) |
| _check_inference(bb, relax.op.concat([x9], axis=1), relax.TensorStructInfo(s4, dtype="float32")) |
| |
| |
| def test_concat_infer_struct_info_zero_rank_input_tensor(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo(())) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) |
| x0 = relax.Var("x", R.Tensor((), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=0)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x0], axis=0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x1], axis=0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x2], axis=None)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x3], axis=None)) |
| |
| |
| def test_concat_infer_struct_info_no_input_tensor(): |
| bb = relax.BlockBuilder() |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([], axis=1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([], axis=None)) |
| |
| |
| def test_concat_infer_struct_info_without_axis_but_tensor_not_one_dimensional(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor((3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=2)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x0], axis=None)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x1], axis=None)) |
| _check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorStructInfo(dtype="float32")) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x3], axis=None)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x4], axis=None)) |
| _check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorStructInfo(s2, "float32")) |
| |
| |
| def test_concat_infer_struct_info_inconsistent_dtype(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((3,))) |
| y = relax.Var("y", R.Tensor((4,), "float32")) |
| z = relax.Var("z", R.Tensor((5,), "int8")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x, y, z], axis=0)) |
| |
| |
| def test_concat_infer_struct_info_inconsistent_ndim(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((4, 5))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) |
| x = relax.Var("x", R.Tensor((3,), "float32")) |
| y0 = relax.Var("y", R.Tensor((4, 5), "float32")) |
| y1 = relax.Var("y", R.Tensor("float32", ndim=2)) |
| y2 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) |
| y3 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) |
| z = relax.Var("z", R.Tensor((5,), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x, y0, z], axis=0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x, y1, z], axis=0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x, y2, z], axis=0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x, y3, z], axis=0)) |
| |
| |
| def test_concat_infer_struct_info_axis_out_of_range(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((3,))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) |
| x0 = relax.Var("x", R.Tensor((3,), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=1)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x0], axis=1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x1], axis=1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x2], axis=1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x3], axis=1)) |
| |
| |
| def test_concat_infer_struct_info_unequal_shape(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo((3, a + 2))) |
| x0 = relax.Var("x", R.Tensor((3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor((3, a + 2), "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| y0 = relax.Var("y", R.Tensor((3, 3), "float32")) |
| y1 = relax.Var("y", R.Tensor((3, a), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x0, y0])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x2, y0])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x1, y1])) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([x3, y1])) |
| |
| |
| def test_concat_infer_struct_info_input_not_tuple(): |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", R.Tensor((3,), "float32")) |
| s = relax.Var("s", relax.ShapeStructInfo((3,))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat(x)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat(s)) |
| |
| |
| def test_concat_infer_struct_info_input_tuple_field_not_tensor(): |
| bb = relax.BlockBuilder() |
| s = relax.Var("s", relax.ShapeStructInfo((3,))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.concat([s])) |
| |
| |
| def test_split_infer_struct_info_by_indices(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 10, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x0, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 3, 4), "float32"), |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 3, 4), "float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x6, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 3, 4), "float32", vdev0), |
| relax.TensorStructInfo((2, 4, 4), "float32", vdev0), |
| relax.TensorStructInfo((2, 3, 4), "float32", vdev0), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x0, [3, 7], axis=-2), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 3, 4), "float32"), |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 3, 4), "float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x2, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x3, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 3, 4), dtype=""), |
| relax.TensorStructInfo((2, 4, 4), dtype=""), |
| relax.TensorStructInfo((2, 3, 4), dtype=""), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x4, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="", ndim=3), |
| relax.TensorStructInfo(dtype="", ndim=3), |
| relax.TensorStructInfo(dtype="", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x5, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype=""), |
| relax.TensorStructInfo(dtype=""), |
| relax.TensorStructInfo(dtype=""), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x0, [-2, 2, 6, 4, 8, 12, 9], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 0, 4), "float32"), |
| relax.TensorStructInfo((2, 2, 4), "float32"), |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 0, 4), "float32"), |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 2, 4), "float32"), |
| relax.TensorStructInfo((2, 0, 4), "float32"), |
| relax.TensorStructInfo((2, 1, 4), "float32"), |
| ] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_by_indices_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x = relax.Var("x", R.Tensor((a, b), "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x, [10, 20], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo([a, T.max(T.min(10, b) - T.min(0, b), 0)], dtype="float32"), |
| relax.TensorStructInfo([a, T.max(T.min(20, b) - T.min(10, b), 0)], dtype="float32"), |
| relax.TensorStructInfo([a, T.max(b - 20, 0)], dtype="float32"), |
| ] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_by_indices_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x0, [3], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, [3], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x2, [3], axis=1), |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_by_n_section(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 10, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x0, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 2, 4), "float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x0, 2, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 5, 4), "float32"), |
| relax.TensorStructInfo((2, 5, 4), "float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x0, 3, axis=-2), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 4, 4), "float32"), |
| relax.TensorStructInfo((2, 2, 4), "float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x2, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x3, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 4, 4), dtype=""), |
| relax.TensorStructInfo((2, 4, 4), dtype=""), |
| relax.TensorStructInfo((2, 2, 4), dtype=""), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x4, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="", ndim=3), |
| relax.TensorStructInfo(dtype="", ndim=3), |
| relax.TensorStructInfo(dtype="", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x5, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype=""), |
| relax.TensorStructInfo(dtype=""), |
| relax.TensorStructInfo(dtype=""), |
| ] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_by_n_section_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x = relax.Var("x", R.Tensor((a, b), "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((a, (b + 2) // 3), "float32"), |
| relax.TensorStructInfo((a, (b + 2) // 3), "float32"), |
| relax.TensorStructInfo((a, b - (b + 2) // 3 * 2), "float32"), |
| ] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_by_n_section_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x0, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x2, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32"), |
| relax.TensorStructInfo(dtype="float32"), |
| ] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 10, 4), "int8")) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x0, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 3, 4), "float16"), |
| relax.TensorStructInfo((2, 4, 4), "float16"), |
| relax.TensorStructInfo((2, 3, 4), "float16"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, [3, 7], axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 3, 4), "int8"), |
| relax.TensorStructInfo((2, 4, 4), "int8"), |
| relax.TensorStructInfo((2, 3, 4), "int8"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x0, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 4, 4), "float16"), |
| relax.TensorStructInfo((2, 4, 4), "float16"), |
| relax.TensorStructInfo((2, 2, 4), "float16"), |
| ] |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, 3, axis=1), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((2, 4, 4), "int8"), |
| relax.TensorStructInfo((2, 4, 4), "int8"), |
| relax.TensorStructInfo((2, 2, 4), "int8"), |
| ] |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_single_output(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| s0 = relax.Var("s", relax.ShapeStructInfo((a, b))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor((a, b), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=2)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x0, [], axis=1), |
| relax.TensorStructInfo((a, b), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, [], axis=1), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x2, [], axis=1), |
| relax.TensorStructInfo(dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x3, [], axis=1), |
| relax.TensorStructInfo(s0, "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x4, [], axis=1), |
| relax.TensorStructInfo(s1, "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x5, [], axis=1), |
| relax.TensorStructInfo(s2, "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x0, 1, axis=1), |
| relax.TensorStructInfo((a, b), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x1, 1, axis=1), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x2, 1, axis=1), |
| relax.TensorStructInfo(dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x3, 1, axis=1), |
| relax.TensorStructInfo(s0, "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x4, 1, axis=1), |
| relax.TensorStructInfo(s1, "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x5, 1, axis=1), |
| relax.TensorStructInfo(s2, "float32"), |
| ) |
| |
| |
| def test_split_indices_or_sections_int64(): |
| x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| split0 = relax.op.split(x, [3, 6], axis=1) |
| split1 = relax.op.split(x, 4, axis=1) |
| |
| assert split0.attrs.indices_or_sections[0].dtype == "int64" |
| assert split0.attrs.indices_or_sections[1].dtype == "int64" |
| assert split1.attrs.indices_or_sections.dtype == "int64" |
| |
| |
| def test_split_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| n = tir.Var("n", "int64") |
| x = relax.Var("x", R.Tensor((16, 4))) |
| y = relax.Var("y", R.Tensor((16, 4), "float32")) |
| z = relax.Var("z", R.Tensor((n, 16))) |
| w = relax.Var("w", R.Tensor((n + 5, 16))) |
| |
| # All relax shape variables are non-negative. When a scope |
| # begins, any TIR variables that are used as shape variables are |
| # declared to be non-negative `tvm.arith.Analyzer`. Because |
| # `relax.op.split` clamps the indices to be within the bounds of |
| # the axis being split, simplifying with non-negative shape |
| # variables can result in much simpler shapes. |
| # |
| # For example, an axis of size `n`, split on the range from 2 to 5 |
| # has size `T.max(T.min(5, n + 5) - T.min(2, n + 5), 0)`. If it |
| # is known that `n >= 0`, then this simplifies down to `3`. |
| bb.begin_scope([x, y, z, w]) |
| |
| _check_inference( |
| bb, |
| relax.op.split(x, 1), |
| R.Tensor([16, 4]), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x, 2), |
| R.Tuple( |
| R.Tensor([8, 4]), |
| R.Tensor([8, 4]), |
| ), |
| ) |
| # Uneven splits are allowed, with the last split being smaller than the others. |
| _check_inference( |
| bb, |
| relax.op.split(x, 3), |
| R.Tuple( |
| R.Tensor([6, 4]), |
| R.Tensor([6, 4]), |
| R.Tensor([4, 4]), |
| ), |
| ) |
| |
| # Dtype of result is inherited from the tensor |
| _check_inference( |
| bb, |
| relax.op.split(y, 2), |
| R.Tuple( |
| R.Tensor([8, 4], "float32"), |
| R.Tensor([8, 4], "float32"), |
| ), |
| ) |
| |
| # Axis can be explicitly specified. Otherwise, defaults to axis=0. |
| _check_inference( |
| bb, relax.op.split(x, [2], axis=1), R.Tuple(R.Tensor([16, 2]), R.Tensor([16, 2])) |
| ) |
| |
| # Split points can be explicitly specified |
| _check_inference( |
| bb, |
| relax.op.split(x, [2]), |
| R.Tuple( |
| R.Tensor([2, 4]), |
| R.Tensor([14, 4]), |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(x, [2, 5]), |
| R.Tuple( |
| R.Tensor([2, 4]), |
| R.Tensor([3, 4]), |
| R.Tensor([11, 4]), |
| ), |
| ) |
| |
| # Splitting a dynamic axis is allowed, and propagates the shape to the output |
| _check_inference( |
| bb, |
| relax.op.split(z, 2), |
| R.Tuple( |
| R.Tensor([(n + 1) // 2, 16]), |
| R.Tensor([n - (n + 1) // 2, 16]), |
| ), |
| ) |
| _check_inference( |
| bb, |
| relax.op.split(z, 3), |
| R.Tuple( |
| R.Tensor([(n + 2) // 3, 16]), |
| R.Tensor([(n + 2) // 3, 16]), |
| R.Tensor([n - (n + 2) // 3 * 2, 16]), |
| ), |
| ) |
| |
| # Splitting a dynamic axis at specific indices is allowed. |
| _check_inference( |
| bb, |
| relax.op.split(w, [2, 5]), |
| R.Tuple( |
| R.Tensor((2, 16)), |
| R.Tensor((3, 16)), |
| R.Tensor((n, 16)), |
| ), |
| ) |
| |
| |
| def test_split_infer_struct_info_non_integer_indices(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("c", "int64") |
| b = tir.Var("d", "int64") |
| x = relax.Var("x", R.Tensor((3, 4), "float32")) |
| |
| with pytest.raises(TypeError): |
| bb.normalize(relax.op.split(x, [a, b], axis=1)) |
| |
| |
| def test_split_invalid_n_section(): |
| n = tir.Var("n", "int64") |
| x = relax.Var("x", R.Tensor((3, 4), "float32")) |
| |
| with pytest.raises((TVMError, TypeError)): |
| relax.op.split(x, 0, axis=1) |
| with pytest.raises((TVMError, TypeError)): |
| relax.op.split(x, -1, axis=1) |
| with pytest.raises((TVMError, TypeError)): |
| relax.op.split(x, n, axis=1) |
| |
| |
| def test_split_infer_struct_info_axis_out_of_range(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=2)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.split(x0, [], axis=2)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.split(x0, [], axis=-3)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.split(x1, 1, axis=2)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.split(x1, 1, axis=-3)) |
| |
| |
| def test_split_infer_invalid_struct_info_indices(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3), "float32")) |
| v = relax.Var("v", relax.PrimStructInfo("int64")) |
| |
| with pytest.raises((TVMError, TypeError)): |
| bb.normalize(relax.op.split(x0, [v], axis=1)) |
| with pytest.raises((TVMError, TypeError)): |
| bb.normalize(relax.op.split(x0, v, axis=1)) |
| |
| |
| def test_split_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.split(x0, 1, axis=1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.split(x1, 1, axis=1)) |
| |
| |
| def test_broadcast_to_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 1, 3))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 1, 3), "float32", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.broadcast_to(x6, (4, 2, 5, 3)), |
| relax.TensorStructInfo((4, 2, 5, 3), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x3, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x4, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x5, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") |
| ) |
| |
| |
| def test_broadcast_to_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| d = tir.Var("d", "int64") |
| x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) |
| x1 = relax.Var("x", R.Tensor((b, 1, 1, d))) |
| |
| _check_inference( |
| bb, |
| relax.op.broadcast_to(x0, (a, b, 1, c, d)), |
| relax.TensorStructInfo((a, b, 1, c, d), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.broadcast_to(x1, (a, b, 1, c, d)), |
| relax.TensorStructInfo((a, b, 1, c, d), dtype=""), |
| ) |
| |
| |
| def test_broadcast_to_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| |
| _check_inference( |
| bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") |
| ) |
| |
| |
| def test_broadcast_to_infer_struct_info_tgt_shape_var(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| d = tir.Var("d", "int64") |
| s0 = relax.Var("s", relax.ShapeStructInfo((b, 1, 1, d))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) |
| s2 = relax.Var("s", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=4)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b, 1, c, d))) |
| stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=5)) |
| stgt2 = relax.Var("stgt", relax.ShapeStructInfo()) |
| |
| _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x2, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x2, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x2, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| |
| |
| def test_broadcast_to_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 1, 3), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 1, 3), "int8")) |
| x2 = relax.Var("x", R.Tensor((2, 1, 3), "int32")) |
| |
| _check_inference( |
| bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float16") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int8") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int32") |
| ) |
| |
| |
| def test_broadcast_to_infer_struct_info_tgt_ndim_less_than_old_ndim(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s", relax.ShapeStructInfo((2, 1))) |
| s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) |
| x0 = relax.Var("x", R.Tensor((2, 1), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=2)) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2,))) |
| stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=1)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x0, (2,))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x0, stgt0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x0, stgt1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x1, (2,))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x1, stgt0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x1, stgt1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x2, (2,))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x2, stgt0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x2, stgt1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x3, (2,))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x3, stgt0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x3, stgt1)) |
| |
| |
| def test_broadcast_to_infer_struct_info_not_broadcastable_static(): |
| bb = relax.BlockBuilder() |
| s = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) |
| x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) |
| stgt = relax.Var("stgt", relax.ShapeStructInfo((2, 1, 6))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x0, (2, 1, 6))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x0, stgt)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x1, (2, 1, 6))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x1, stgt)) |
| |
| |
| def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| s = relax.Var("s", relax.ShapeStructInfo((2, a))) |
| x0 = relax.Var("x", R.Tensor((2, a), "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) |
| stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2, b))) |
| stgt1 = relax.Var("stgt", relax.ShapeStructInfo((2, 1))) |
| stgt2 = relax.Var("stgt", relax.ShapeStructInfo((b, a))) |
| |
| _check_inference( |
| bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorStructInfo((2, b), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorStructInfo((2, 1), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorStructInfo((b, a), "float32") |
| ) |
| _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| _check_inference( |
| bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorStructInfo((2, b), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorStructInfo((2, 1), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorStructInfo((b, a), "float32") |
| ) |
| _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) |
| _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) |
| |
| |
| def test_broadcast_to_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 1, 3))) |
| x1 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) |
| stgt = relax.Var("stgt", relax.TensorStructInfo((4, 2, 5, 3), dtype="")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x0, (4, 2, 5, 3))) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.broadcast_to(x1, stgt)) |
| |
| |
| def test_collapse_sum_like_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| y0 = relax.Var("y", R.Tensor((3, 4), "float32")) |
| y1 = relax.Var("y", R.Tensor("float32", ndim=2)) |
| y2 = relax.Var("y", R.Tensor("float32")) |
| y3 = relax.Var("y", R.Tensor((3, 4))) |
| y4 = relax.Var("y", R.Tensor(ndim=2)) |
| y5 = relax.Var("y", R.Tensor((1, 4))) |
| y6 = relax.Var("y", R.Tensor((3, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x3, y6), relax.TensorStructInfo((3, 4), "float32", vdev0) |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2) |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2) |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y2), relax.TensorStructInfo(dtype="float32", ndim=-1) |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x2, y4), relax.TensorStructInfo(dtype="float32", ndim=2) |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x4, y1), relax.TensorStructInfo(dtype="", ndim=2) |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4), dtype="") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4), "float32") |
| ) |
| |
| |
| def test_collapse_sum_like_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) |
| y0 = relax.Var("y", R.Tensor((4, a), "float32")) |
| x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) |
| y1 = relax.Var("x", R.Tensor((1, a + b), "float32")) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a + b), "float32") |
| ) |
| |
| |
| def test_collapse_sum_like_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) |
| s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s2", relax.ShapeStructInfo()) |
| s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4))) |
| s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2)) |
| s5 = relax.Var("s5", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) |
| y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) |
| y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32")) |
| |
| _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo(s3, "float32")) |
| _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(s4, "float32")) |
| _check_inference(bb, relax.op.collapse_sum_like(x2, y2), relax.TensorStructInfo(s5, "float32")) |
| |
| |
| def test_collapse_sum_like_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) |
| y0 = relax.Var("y", R.Tensor((3, 4), "float16")) |
| y1 = relax.Var("y", R.Tensor((3, 4), "int8")) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float16") |
| ) |
| _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8")) |
| |
| |
| def test_collapse_sum_like_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) |
| x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_like(x0, x1)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_like(x2, x0)) |
| |
| |
| def test_collapse_sum_like_infer_struct_info_shape_mismatch(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) |
| y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) |
| |
| s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) |
| s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5))) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) |
| |
| s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5))) |
| s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5))) |
| x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_like(x0, y0)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_like(x1, y1)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_like(x2, y2)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_like(x3, y3)) |
| |
| |
| def test_collapse_sum_to_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), relax.TensorStructInfo((3, 4), "")) |
| _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), relax.TensorStructInfo((3, 4), "")) |
| _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), relax.TensorStructInfo((3, 4), "")) |
| |
| |
| def test_collapse_sum_to_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) |
| x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, a), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorStructInfo((1, a + b), "float32") |
| ) |
| |
| |
| def test_collapse_sum_to_infer_struct_info_shape_var(): |
| bb = relax.BlockBuilder() |
| s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) |
| s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s2", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") |
| ) |
| |
| |
| def test_collapse_sum_to_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float16") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "int8") |
| ) |
| |
| |
| def test_collapse_sum_to_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) |
| x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x0, x0)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x0, x2)) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x1, x1)) |
| |
| |
| def test_collapse_sum_to_infer_struct_info_shape_mismatch(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) |
| |
| s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) |
| x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| |
| s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5))) |
| x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x1, (3, b, 5))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x2, (4, 4, 5))) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) |
| |
| |
| def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| d = tir.Var("d", "int64") |
| s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) |
| s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) |
| s2 = relax.Var("s2", relax.ShapeStructInfo()) |
| x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("")) |
| x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) |
| x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) |
| x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) |
| stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b))) |
| stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2)) |
| stgt2 = relax.Var("stgt2", relax.ShapeStructInfo()) |
| |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32") |
| ) |
| _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), relax.TensorStructInfo(stgt0, "")) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32") |
| ) |
| _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), relax.TensorStructInfo(stgt1, "")) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32") |
| ) |
| _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), relax.TensorStructInfo(stgt2, "")) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32") |
| ) |
| _check_inference( |
| bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32") |
| ) |
| |
| |
| def test_repeat_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 10, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, |
| relax.op.repeat(x0, 2, axis=0), |
| relax.TensorStructInfo((4, 10, 4), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.repeat(x6, 2, axis=0), |
| relax.TensorStructInfo((4, 10, 4), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, |
| relax.op.repeat(x0, 2, axis=-2), |
| relax.TensorStructInfo((2, 20, 4), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.repeat(x0, 2), |
| relax.TensorStructInfo((160,), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.repeat(x1, 2, axis=0), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| _check_inference( |
| bb, |
| relax.op.repeat(x1, 2), |
| relax.TensorStructInfo(dtype="float32", ndim=1), |
| ) |
| _check_inference(bb, relax.op.repeat(x2, 2, axis=0), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.repeat(x2, 2), relax.TensorStructInfo(dtype="float32", ndim=1)) |
| _check_inference( |
| bb, |
| relax.op.repeat(x3, 2, axis=0), |
| relax.TensorStructInfo((4, 10, 4), dtype=""), |
| ) |
| _check_inference(bb, relax.op.repeat(x4, 2, axis=0), relax.TensorStructInfo(dtype="", ndim=3)) |
| _check_inference(bb, relax.op.repeat(x5, 2, axis=0), relax.TensorStructInfo(dtype="")) |
| |
| |
| def test_repeat_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| x = relax.Var("x", R.Tensor((a, b, c), "float32")) |
| |
| _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorStructInfo((a * 2, b, c), "float32")) |
| _check_inference( |
| bb, |
| relax.op.repeat(x, 2, -1), |
| relax.TensorStructInfo((a, b, c * 2), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.repeat(x, 2), |
| relax.TensorStructInfo((a * b * c * 2,), "float32"), |
| ) |
| |
| |
| def test_repeat_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) |
| |
| _check_inference(bb, relax.op.repeat(x0, 2, 0), relax.TensorStructInfo((4, 3, 4), "float16")) |
| _check_inference(bb, relax.op.repeat(x1, 2, 0), relax.TensorStructInfo((4, 3, 4), "int8")) |
| |
| |
| def test_repeat_infer_struct_info_axis_out_of_range(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.repeat(x0, 2, 3)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.repeat(x0, 2, -4)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.repeat(x1, 2, 3)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.repeat(x1, 2, -4)) |
| # okay |
| bb.normalize(relax.op.repeat(x2, 2, 3)) |
| bb.normalize(relax.op.repeat(x2, 2, -4)) |
| |
| |
| def test_repeat_return_data_sinfo(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| |
| _check_inference(bb, relax.op.repeat(x0, 1, 0), x0.struct_info) |
| _check_inference(bb, relax.op.repeat(x0, 1, -1), x0.struct_info) |
| _check_inference(bb, relax.op.repeat(x1, 1, 0), x1.struct_info) |
| _check_inference(bb, relax.op.repeat(x2, 1, 0), x2.struct_info) |
| |
| |
| def test_repeat_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) |
| x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| r1 = tir.Var("r", "float32") |
| r2 = tir.StringImm("abc") |
| |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.repeat(x0, 2)) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.repeat(x1, 2)) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.repeat(x2, 1.5)) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.repeat(x2, r1)) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.repeat(x2, r2)) |
| |
| |
| def test_tile_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 10, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor()) |
| x6 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) |
| |
| _check_inference( |
| bb, |
| relax.op.tile(x0, 2), |
| relax.TensorStructInfo((2, 10, 8), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x6, 2), |
| relax.TensorStructInfo((2, 10, 8), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x0, (3, 2)), |
| relax.TensorStructInfo((2, 30, 8), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x0, (4, 3, 2)), |
| relax.TensorStructInfo((8, 30, 8), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x0, (5, 4, 3, 2)), |
| relax.TensorStructInfo((5, 8, 30, 8), "float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x1, 2), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x1, (5, 4, 3, 2)), |
| relax.TensorStructInfo(dtype="float32", ndim=4), |
| ) |
| _check_inference(bb, relax.op.tile(x2, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="float32")) |
| _check_inference( |
| bb, |
| relax.op.tile(x3, 2), |
| relax.TensorStructInfo((2, 10, 8), dtype=""), |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x3, (5, 4, 3, 2)), |
| relax.TensorStructInfo((5, 8, 30, 8), dtype=""), |
| ) |
| _check_inference(bb, relax.op.tile(x4, 2), relax.TensorStructInfo(dtype="", ndim=3)) |
| _check_inference(bb, relax.op.tile(x4, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="", ndim=4)) |
| _check_inference(bb, relax.op.tile(x5, (5, 4, 3, 2)), relax.TensorStructInfo(dtype="")) |
| |
| |
| def test_tile_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| x = relax.Var("x", R.Tensor((a, b, c), "float32")) |
| |
| _check_inference(bb, relax.op.tile(x, 2), relax.TensorStructInfo((a, b, c * 2), "float32")) |
| _check_inference( |
| bb, relax.op.tile(x, (3, 2)), relax.TensorStructInfo((a, b * 3, c * 2), "float32") |
| ) |
| _check_inference( |
| bb, relax.op.tile(x, (4, 3, 2)), relax.TensorStructInfo((a * 4, b * 3, c * 2), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.tile(x, (5, 4, 3, 2)), |
| relax.TensorStructInfo((5, a * 4, b * 3, c * 2), "float32"), |
| ) |
| |
| |
| def test_tile_infer_struct_info_more_input_dtype(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) |
| x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) |
| |
| _check_inference(bb, relax.op.tile(x0, (3, 2)), relax.TensorStructInfo((2, 9, 8), "float16")) |
| _check_inference(bb, relax.op.tile(x1, (3, 2)), relax.TensorStructInfo((2, 9, 8), "int8")) |
| |
| |
| def test_tile_return_data_sinfo(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| |
| _check_inference(bb, relax.op.tile(x0, 1), x0.struct_info) |
| _check_inference(bb, relax.op.tile(x0, (1, 1)), x0.struct_info) |
| _check_inference(bb, relax.op.tile(x0, (1, 1, 1)), x0.struct_info) |
| _check_inference(bb, relax.op.tile(x1, 1), x1.struct_info) |
| _check_inference(bb, relax.op.tile(x2, 1), x2.struct_info) |
| |
| |
| def test_tile_infer_struct_info_wrong_input_type(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) |
| x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) |
| x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) |
| r1 = tir.Var("a", "float32") |
| r2 = tir.StringImm("abc") |
| |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.tile(x0, 2)) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.tile(x1, 2)) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.tile(x2, (2, 1.5, 2))) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.tile(x2, (2, r1))) |
| with pytest.raises((TypeError, TVMError)): |
| bb.normalize(relax.op.tile(x2, r2)) |
| |
| |
| def test_flip_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float16", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("int32")) |
| x3 = relax.Var("x", R.Tensor((2, 10, 4))) |
| x4 = relax.Var("x", R.Tensor(ndim=3)) |
| x5 = relax.Var("x", R.Tensor((2, 10, 4), "float32", vdev0)) |
| |
| _check_inference(bb, relax.op.flip(x0, axis=1), relax.TensorStructInfo((2, 10, 4), "float32")) |
| _check_inference( |
| bb, relax.op.flip(x5, axis=1), relax.TensorStructInfo((2, 10, 4), "float32", vdev0) |
| ) |
| _check_inference(bb, relax.op.flip(x1, axis=0), R.Tensor("float16", ndim=3)) |
| _check_inference(bb, relax.op.flip(x2, axis=0), R.Tensor("int32")) |
| _check_inference(bb, relax.op.flip(x3, axis=2), R.Tensor((2, 10, 4))) |
| _check_inference(bb, relax.op.flip(x4, axis=2), R.Tensor(ndim=3)) |
| |
| |
| def test_flip_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x = relax.Var("x", R.Tensor((a, b), "float32")) |
| |
| _check_inference(bb, relax.op.flip(x, axis=0), relax.TensorStructInfo((a, b), "float32")) |
| |
| |
| def test_flip_infer_struct_info_wrong_inputs(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.flip(x0, axis=3)) |
| |
| |
| def test_gather_elements_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) |
| i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64")) |
| i1 = relax.Var("i", R.Tensor((2, 3, 4))) |
| i2 = relax.Var("i", R.Tensor("int64", ndim=3)) |
| i3 = relax.Var("i", R.Tensor(ndim=3)) |
| i4 = relax.Var("i", R.Tensor((2, 3, 4), "int64", vdev0)) |
| |
| _check_inference( |
| bb, relax.op.gather_elements(x0, i0, axis=1), relax.TensorStructInfo((2, 3, 4), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.gather_elements(x3, i4, axis=1), |
| relax.TensorStructInfo((2, 3, 4), "float32", vdev0), |
| ) |
| _check_inference( |
| bb, |
| relax.op.gather_elements(x1, i0, axis=1), |
| relax.TensorStructInfo((2, 3, 4), dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.gather_elements(x2, i0, axis=0), |
| relax.TensorStructInfo(dtype="float32", ndim=-1), |
| ) |
| _check_inference( |
| bb, relax.op.gather_elements(x0, i1, axis=1), relax.TensorStructInfo((2, 3, 4), "float32") |
| ) |
| _check_inference( |
| bb, |
| relax.op.gather_elements(x1, i2, axis=1), |
| relax.TensorStructInfo(dtype="float32", ndim=3), |
| ) |
| _check_inference( |
| bb, |
| relax.op.gather_elements(x2, i3, axis=0), |
| relax.TensorStructInfo(dtype="float32", ndim=-1), |
| ) |
| |
| |
| def test_gather_elements_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| x = relax.Var("x", R.Tensor((a, b), "float32")) |
| i = relax.Var("i", R.Tensor((a, b), "int64")) |
| |
| _check_inference( |
| bb, relax.op.gather_elements(x, i, axis=1), relax.TensorStructInfo((a, b), "float32") |
| ) |
| |
| |
| def test_gather_elements_infer_struct_info_wrong_inputs(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor((2, 3), "float32")) |
| i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64")) |
| i1 = relax.Var("i", R.Tensor((2, 3), "int64")) |
| i2 = relax.Var("i", R.Tensor((2, 3, 4), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.gather_elements(x0, i0, axis=3)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.gather_elements(x0, i1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.gather_elements(x1, i0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.gather_elements(x0, i2)) |
| |
| |
| def test_gather_nd_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| x1 = relax.Var("x", R.Tensor("float32", ndim=3)) |
| x2 = relax.Var("x", R.Tensor("float32")) |
| x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) |
| i0 = relax.Var("i", R.Tensor((2, 2), "int64")) |
| i1 = relax.Var("i", R.Tensor((2, 2))) |
| i2 = relax.Var("i", R.Tensor("int64", ndim=2)) |
| i3 = relax.Var("i", R.Tensor(ndim=2)) |
| i4 = relax.Var("i", R.Tensor((2, 2), "int64", vdev0)) |
| |
| _check_inference(bb, relax.op.gather_nd(x0, i0), relax.TensorStructInfo((2, 4), "float32")) |
| _check_inference( |
| bb, relax.op.gather_nd(x3, i4), relax.TensorStructInfo((2, 4), "float32", vdev0) |
| ) |
| _check_inference( |
| bb, relax.op.gather_nd(x1, i0), relax.TensorStructInfo(dtype="float32", ndim=2) |
| ) |
| _check_inference( |
| bb, relax.op.gather_nd(x2, i0), relax.TensorStructInfo(dtype="float32", ndim=-1) |
| ) |
| _check_inference(bb, relax.op.gather_nd(x0, i1), relax.TensorStructInfo((2, 4), "float32")) |
| _check_inference(bb, relax.op.gather_nd(x1, i2), relax.TensorStructInfo(dtype="float32")) |
| _check_inference(bb, relax.op.gather_nd(x2, i3), relax.TensorStructInfo(dtype="float32")) |
| |
| |
| def test_gather_nd_infer_struct_info_shape_symbolic(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| x = relax.Var("x", R.Tensor((a, b, c), "float32")) |
| i = relax.Var("i", R.Tensor((2, 2), "int64")) |
| |
| _check_inference(bb, relax.op.gather_nd(x, i), relax.TensorStructInfo((2, c), "float32")) |
| |
| |
| def test_gather_nd_infer_struct_info_wrong_inputs(): |
| bb = relax.BlockBuilder() |
| x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) |
| i0 = relax.Var("i", R.Tensor((2, 4), "int64")) # indices too long |
| i1 = relax.Var("i", R.Tensor((2, 2), "float32")) # wrong dtype |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.gather_nd(x0, i0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.gather_nd(x0, i1)) |
| |
| |
| def test_scatter_elements_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| d0 = relax.Var("data", R.Tensor((4, 4), "float32")) |
| d1 = relax.Var("data", R.Tensor(dtype="float32", ndim=2)) |
| d2 = relax.Var("data", R.Tensor("float32")) |
| d3 = relax.Var("data", R.Tensor((4, 4), "float32", vdev0)) |
| i0 = relax.Var("indices", R.Tensor((2, 2), "int64")) |
| i1 = relax.Var("indices", R.Tensor((2, 2))) |
| i2 = relax.Var("indices", R.Tensor(dtype="int64", ndim=2)) |
| i3 = relax.Var("indices", R.Tensor(ndim=2)) |
| i4 = relax.Var("indices", R.Tensor((2, 2), "int64", vdev0)) |
| u0 = relax.Var("updates", R.Tensor((2, 2), "float32")) |
| u1 = relax.Var("updates", R.Tensor((2, 2), "float32", vdev0)) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d0, i0, u0, 0, "updates"), |
| relax.TensorStructInfo((4, 4), dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d3, i4, u1, 0, "updates"), |
| relax.TensorStructInfo((4, 4), dtype="float32", vdevice=vdev0), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d1, i0, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d2, i0, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=-1), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d0, i1, u0, 0, "updates"), |
| relax.TensorStructInfo((4, 4), dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d1, i1, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d2, i1, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=-1), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d0, i2, u0, 0, "updates"), |
| relax.TensorStructInfo((4, 4), dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d1, i2, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d2, i2, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=-1), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d0, i3, u0, 0, "updates"), |
| relax.TensorStructInfo((4, 4), dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d1, i3, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d2, i3, u0, 0, "updates"), |
| relax.TensorStructInfo(dtype="float32", ndim=-1), |
| ) |
| |
| |
| def test_scatter_elements_infer_struct_info_symbolic_shape(): |
| bb = relax.BlockBuilder() |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| c = tir.Var("c", "int64") |
| d = tir.Var("d", "int64") |
| e = tir.Var("e", "int64") |
| f = tir.Var("f", "int64") |
| |
| d0 = relax.Var("data", R.Tensor((a, b), "float32")) |
| i0 = relax.Var("indices", R.Tensor((c, d), "int64")) |
| u0 = relax.Var("updates", R.Tensor((c, d), "float32")) |
| u1 = relax.Var("updates", R.Tensor((e, f), "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d0, i0, u0, 0, "updates"), |
| relax.TensorStructInfo((a, b), dtype="float32"), |
| ) |
| _check_inference( |
| bb, |
| relax.op.scatter_elements(d0, i0, u1, 0, "updates"), |
| relax.TensorStructInfo((a, b), dtype="float32"), |
| ) |
| |
| |
| def test_scatter_elements_infer_struct_info_wrong_indices_type(): |
| bb = relax.BlockBuilder() |
| d0 = relax.Var("data", R.Tensor((4, 4), "float32")) |
| i0 = relax.Var("indices", R.Tensor((2, 2), "float32")) |
| u0 = relax.Var("updates", R.Tensor((2, 2), "float32")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i0, u0)) |
| |
| |
| def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): |
| a = tir.Var("a", "int64") |
| b = tir.Var("b", "int64") |
| |
| bb = relax.BlockBuilder() |
| d0 = relax.Var("data", R.Tensor((4, 4), "float32")) |
| i0 = relax.Var("indices", R.Tensor((3, 3), "int64")) |
| i1 = relax.Var("indices", R.Tensor((3, 3, 3), "int64")) |
| i2 = relax.Var("indices", R.Tensor((a, b), "int64")) |
| u0 = relax.Var("updates", R.Tensor((3, 2), "float32")) |
| u1 = relax.Var("updates", R.Tensor((3, 2, 3), "float32")) |
| u2 = relax.Var("updates", R.Tensor((3, 3, 3), "float32")) |
| u3 = relax.Var("updates", R.Tensor((a + 1, b), "float32")) |
| u4 = relax.Var("updates", R.Tensor((3, 3), "float16")) |
| |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i0, u0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i1, u0)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i0, u1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i1, u1)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i1, u2)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i2, u3)) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.scatter_elements(d0, i0, u4)) |
| |
| |
| def test_scatter_nd_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| |
| d0 = relax.Var("data", R.Tensor((8,), "float32")) |
| i0 = relax.Var("indices", R.Tensor((4, 1), "int64")) |
| u0 = relax.Var("updates", R.Tensor((4,), "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.scatter_nd(d0, i0, u0, "update"), |
| relax.TensorStructInfo((8,), dtype="float32"), |
| ) |
| |
| d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32")) |
| i1 = relax.Var("indices", R.Tensor((2, 1), "int64")) |
| u1 = relax.Var("updates", R.Tensor((2, 4, 4), "float32")) |
| |
| _check_inference( |
| bb, |
| relax.op.scatter_nd(d1, i1, u1, "update"), |
| relax.TensorStructInfo((4, 4, 4), dtype="float32"), |
| ) |
| |
| |
| def test_meshgrid_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| vdev0 = VDevice("llvm") |
| t0 = relax.Var("t0", R.Tensor((3,), "float32")) |
| t1 = relax.Var("t1", R.Tensor((4,), "float32")) |
| t2 = relax.Var("t2", R.Tensor("float32", ndim=1)) |
| t3 = relax.Var("t3", R.Tensor((5,), "float32", vdev0)) |
| |
| _check_inference( |
| bb, |
| relax.op.meshgrid((t0, t1), indexing="ij"), |
| relax.TupleStructInfo( |
| [relax.TensorStructInfo((3, 4), "float32"), relax.TensorStructInfo((3, 4), "float32")] |
| ), |
| ) |
| |
| _check_inference( |
| bb, |
| relax.op.meshgrid((t3, t1), indexing="ij"), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((5, 4), "float32", vdev0), |
| relax.TensorStructInfo((5, 4), "float32", vdev0), |
| ] |
| ), |
| ) |
| |
| _check_inference( |
| bb, |
| relax.op.meshgrid((t2, t1), indexing="xy"), |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| relax.TensorStructInfo(dtype="float32", ndim=2), |
| ] |
| ), |
| ) |
| |
| _check_inference( |
| bb, |
| relax.op.meshgrid((t0,), indexing="ij"), |
| relax.TupleStructInfo([relax.TensorStructInfo((3,), "float32")]), |
| ) |
| |
| |
| def test_one_hot_infer_struct_info(): |
| bb = relax.BlockBuilder() |
| |
| # Test case 1: Basic usage |
| i0 = relax.Var("indices", R.Tensor((3,), "int32")) |
| _check_inference( |
| bb, |
| relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5), |
| relax.TensorStructInfo((3, 5), "float32"), |
| ) |
| |
| # Test case 2: With specified axis |
| i1 = relax.Var("indices", R.Tensor((2, 2), "int32")) |
| _check_inference( |
| bb, |
| relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1), |
| relax.TensorStructInfo((2, 3, 2), "int64"), |
| ) |
| |
| # Test case 3: With symbolic shape |
| n = tir.Var("n", "int64") |
| i2 = relax.Var("indices", R.Tensor((n,), "int32")) |
| _check_inference( |
| bb, |
| relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4), |
| relax.TensorStructInfo((n, 4), "float32"), |
| ) |
| |
| # Test case 4: With unknown shape |
| i3 = relax.Var("indices", R.Tensor("int32")) |
| _check_inference( |
| bb, |
| relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6), |
| relax.TensorStructInfo(dtype="float32"), |
| ) |
| |
| # Test case 5: With different on_value and off_value dtypes |
| i3 = relax.Var("indices", R.Tensor((2, 3), "int32")) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0), 5)) |
| |
| # Test case 6: With invalid indices dtype |
| i4 = relax.Var("indices", R.Tensor((2, 3), "float32")) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.one_hot(i4, relax.PrimValue(1.0), relax.PrimValue(0.0), 5)) |
| |
| # Test case 7: With invalid depth |
| i5 = relax.Var("indices", R.Tensor((2, 3), "int32")) |
| with pytest.raises(TVMError): |
| bb.normalize(relax.op.one_hot(i5, relax.PrimValue(1.0), relax.PrimValue(0.0), -1)) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |