blob: e0992d276fbb28ece4c607e848efa50f3f828ded [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op, VDevice
from tvm.script import ir as I, relax as R, tir as T
import numpy as np
def test_op_correctness():
    """Each index-op wrapper must build a Call to its registered relax Op."""
    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    # Use an integer index tensor: `take` rejects floating-point indices during
    # struct-info inference, so a float32 index would be a misleading example.
    idx = relax.Var("idx", R.Tensor((2,), "int64"))
    assert relax.op.take(x, idx, axis=1).op == Op.get("relax.take")
    assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == Op.get(
        "relax.strided_slice"
    )
    assert relax.op.dynamic_strided_slice(x, x, x, x).op == Op.get("relax.dynamic_strided_slice")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
    """Normalize ``call`` with ``bb`` and require its struct info to equal ``expected_sinfo``.

    Structural (not identity) equality is used, so freshly built expected
    struct infos compare equal to the inferred ones.
    """
    normalized_call = bb.normalize(call)
    inferred_sinfo = normalized_call.struct_info
    tvm.ir.assert_structural_equal(inferred_sinfo, expected_sinfo)
def test_take_infer_struct_info():
    """Check ``take`` struct-info inference over many data/index combinations.

    Data tensors cover fully known, rank-only and rank-unknown shapes, with and
    without a dtype; index tensors cover 1-d and 2-d indices with the same
    variations. The expected output keeps the data dtype and, when both shapes
    are fully known, splices the index shape in place of the taken axis
    (e.g. ``(4, 10)`` taken by ``(6, 4)`` along axis 1 gives ``(4, 6, 4)``).
    """
    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    tsi = relax.TensorStructInfo

    # Rank-2 data tensors.
    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
    x2 = relax.Var("x", R.Tensor("float32"))
    x3 = relax.Var("x", R.Tensor((4, 10)))
    x4 = relax.Var("x", R.Tensor(ndim=2))
    x5 = relax.Var("x", R.Tensor())
    x6 = relax.Var("x", R.Tensor((4, 10), "float32", vdev0))
    # Rank-1 data tensors; these are taken without an explicit axis.
    y0 = relax.Var("y", R.Tensor((10,), "float32"))
    y1 = relax.Var("y", R.Tensor("float32", ndim=1))
    y2 = relax.Var("y", R.Tensor((10,)))
    y3 = relax.Var("y", R.Tensor(ndim=1))
    # Index tensors: idx0-idx3 are 1-d, idx4-idx7 are 2-d, idx8 carries a vdevice.
    idx0 = relax.Var("idx", R.Tensor((6,), "int64"))
    idx1 = relax.Var("idx", R.Tensor("int64", ndim=1))
    idx2 = relax.Var("idx", R.Tensor((6,)))
    idx3 = relax.Var("idx", R.Tensor(ndim=1))
    idx4 = relax.Var("idx", R.Tensor((6, 4), "int64"))
    idx5 = relax.Var("idx", R.Tensor("int64", ndim=2))
    idx6 = relax.Var("idx", R.Tensor((6, 4)))
    idx7 = relax.Var("idx", R.Tensor(ndim=2))
    idx8 = relax.Var("idx", R.Tensor((6,), "int64", vdev0))

    # (data, indices, axis, expected); axis None means the axis argument is omitted.
    cases = [
        (x0, idx0, 1, tsi((4, 6), "float32")),
        (x6, idx8, 1, tsi((4, 6), "float32", vdev0)),
        (x0, idx0, -1, tsi((4, 6), "float32")),
        (x1, idx0, 1, tsi(dtype="float32", ndim=2)),
        (x2, idx0, 1, tsi(dtype="float32")),
        (x3, idx0, 1, tsi((4, 6), dtype="")),
        (x4, idx0, 1, tsi(dtype="", ndim=2)),
        (x5, idx0, 1, tsi(dtype="")),
        # idx1: rank-1 index with unknown extent.
        (x0, idx1, 1, tsi(dtype="float32", ndim=2)),
        (x1, idx1, 1, tsi(dtype="float32", ndim=2)),
        (x2, idx1, 1, tsi(dtype="float32")),
        (x3, idx1, 1, tsi(dtype="", ndim=2)),
        (x4, idx1, 1, tsi(dtype="", ndim=2)),
        (x5, idx1, 1, tsi(dtype="")),
        # idx2: known extent, no dtype.
        (x0, idx2, 1, tsi((4, 6), "float32")),
        (x1, idx2, 1, tsi(dtype="float32", ndim=2)),
        (x2, idx2, 1, tsi(dtype="float32")),
        (x3, idx2, 1, tsi((4, 6), dtype="")),
        (x4, idx2, 1, tsi(dtype="", ndim=2)),
        (x5, idx2, 1, tsi(dtype="")),
        # idx3: rank-1 only.
        (x0, idx3, 1, tsi(dtype="float32", ndim=2)),
        (x1, idx3, 1, tsi(dtype="float32", ndim=2)),
        (x2, idx3, 1, tsi(dtype="float32")),
        (x3, idx3, 1, tsi(dtype="", ndim=2)),
        (x4, idx3, 1, tsi(dtype="", ndim=2)),
        (x5, idx3, 1, tsi(dtype="")),
        # idx4: known 2-d shape with dtype.
        (x0, idx4, 0, tsi((6, 4, 10), dtype="float32")),
        (x0, idx4, 1, tsi((4, 6, 4), dtype="float32")),
        (x1, idx4, 1, tsi(dtype="float32", ndim=3)),
        (x2, idx4, 1, tsi(dtype="float32")),
        (x3, idx4, 1, tsi((4, 6, 4), dtype="")),
        (x4, idx4, 1, tsi(dtype="", ndim=3)),
        (x5, idx4, 1, tsi(dtype="")),
        # idx5: rank-2 with dtype, unknown shape.
        (x0, idx5, 0, tsi(dtype="float32", ndim=3)),
        (x0, idx5, 1, tsi(dtype="float32", ndim=3)),
        (x1, idx5, 1, tsi(dtype="float32", ndim=3)),
        (x2, idx5, 1, tsi(dtype="float32")),
        (x3, idx5, 1, tsi(dtype="", ndim=3)),
        (x4, idx5, 1, tsi(dtype="", ndim=3)),
        (x5, idx5, 1, tsi(dtype="")),
        # idx6: known 2-d shape, no dtype.
        (x0, idx6, 0, tsi((6, 4, 10), dtype="float32")),
        (x0, idx6, 1, tsi((4, 6, 4), dtype="float32")),
        (x1, idx6, 1, tsi(dtype="float32", ndim=3)),
        (x2, idx6, 1, tsi(dtype="float32")),
        (x3, idx6, 1, tsi((4, 6, 4), dtype="")),
        (x4, idx6, 1, tsi(dtype="", ndim=3)),
        (x5, idx6, 1, tsi(dtype="")),
        # idx7: rank-2 only.
        (x0, idx7, 0, tsi(dtype="float32", ndim=3)),
        (x0, idx7, 1, tsi(dtype="float32", ndim=3)),
        (x1, idx7, 1, tsi(dtype="float32", ndim=3)),
        (x2, idx7, 1, tsi(dtype="float32")),
        (x3, idx7, 1, tsi(dtype="", ndim=3)),
        (x4, idx7, 1, tsi(dtype="", ndim=3)),
        (x5, idx7, 1, tsi(dtype="")),
        # Rank-1 data, axis omitted.
        (y0, idx0, None, tsi((6,), "float32")),
        (y1, idx0, None, tsi(dtype="float32", ndim=1)),
        (y2, idx0, None, tsi((6,), dtype="")),
        (y3, idx0, None, tsi(dtype="", ndim=1)),
        (y0, idx1, None, tsi(dtype="float32", ndim=1)),
        (y1, idx1, None, tsi(dtype="float32", ndim=1)),
        (y2, idx1, None, tsi(dtype="", ndim=1)),
        (y3, idx1, None, tsi(dtype="", ndim=1)),
        (y0, idx2, None, tsi((6,), "float32")),
        (y1, idx2, None, tsi(dtype="float32", ndim=1)),
        (y2, idx2, None, tsi((6,), dtype="")),
        (y3, idx2, None, tsi(dtype="", ndim=1)),
        (y0, idx3, None, tsi(dtype="float32", ndim=1)),
        (y1, idx3, None, tsi(dtype="float32", ndim=1)),
        (y2, idx3, None, tsi(dtype="", ndim=1)),
        (y3, idx3, None, tsi(dtype="", ndim=1)),
        (y0, idx4, None, tsi((6, 4), "float32")),
        (y1, idx4, None, tsi(dtype="float32", ndim=2)),
        (y2, idx4, None, tsi((6, 4), dtype="")),
        (y3, idx4, None, tsi(dtype="", ndim=2)),
        (y0, idx5, None, tsi(dtype="float32", ndim=2)),
        (y1, idx5, None, tsi(dtype="float32", ndim=2)),
        (y2, idx5, None, tsi(dtype="", ndim=2)),
        (y3, idx5, None, tsi(dtype="", ndim=2)),
        (y0, idx6, None, tsi((6, 4), "float32")),
        (y1, idx6, None, tsi(dtype="float32", ndim=2)),
        (y2, idx6, None, tsi((6, 4), dtype="")),
        (y3, idx6, None, tsi(dtype="", ndim=2)),
        (y0, idx7, None, tsi(dtype="float32", ndim=2)),
        (y1, idx7, None, tsi(dtype="float32", ndim=2)),
        (y2, idx7, None, tsi(dtype="", ndim=2)),
        (y3, idx7, None, tsi(dtype="", ndim=2)),
    ]

    for data, indices, axis, expected in cases:
        if axis is None:
            call = relax.op.take(data, indices)
        else:
            call = relax.op.take(data, indices, axis=axis)
        _check_inference(bb, call, expected)
def test_take_infer_struct_info_scalar_tensor_index():
    """A 0-d int64 index tensor removes the taken axis from the output shape."""
    bb = relax.BlockBuilder()
    data = relax.Var("x", R.Tensor((4, 10), "float32"))
    scalar_index = relax.Var("idx", R.Tensor([], "int64"))
    for axis, remaining_shape in ((0, [10]), (1, [4])):
        _check_inference(
            bb,
            relax.op.take(data, scalar_index, axis=axis),
            relax.TensorStructInfo(remaining_shape, "float32"),
        )
def test_take_infer_struct_info_prim_value_index():
    """A scalar ``R.Prim("int64")`` index removes the taken axis from the output shape."""
    bb = relax.BlockBuilder()
    data = relax.Var("x", R.Tensor((4, 10), "float32"))
    prim_index = relax.Var("idx", R.Prim("int64"))
    for axis, remaining_shape in ((0, [10]), (1, [4])):
        _check_inference(
            bb,
            relax.op.take(data, prim_index, axis=axis),
            relax.TensorStructInfo(remaining_shape, "float32"),
        )
def test_take_infer_struct_info_shape_symbolic():
    """``take`` splices symbolic index dimensions into the output shape.

    The data dtype is kept (``""`` when the data has none); whether the index
    has a dtype does not change the result.
    """
    bb = relax.BlockBuilder()
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    i = tir.Var("i", "int64")
    j = tir.Var("j", "int64")
    k = tir.Var("k", "int64")
    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
    x1 = relax.Var("x", R.Tensor((m, n)))
    y0 = relax.Var("y", R.Tensor((n,), "float32"))
    y1 = relax.Var("y", R.Tensor((n,)))
    idx0 = relax.Var("idx", R.Tensor((i,), "int64"))
    idx1 = relax.Var("idx", R.Tensor((i,)))
    idx2 = relax.Var("idx", R.Tensor((i, j, k)))

    _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32"))
    _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype=""))
    _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32"))
    _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype=""))
    # A rank-3 index replaces the taken axis with all three of its dimensions,
    # for data both with (x0) and without (x1) a dtype.
    _check_inference(
        bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((m, i, j, k), "float32")
    )
    _check_inference(
        bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((m, i, j, k), dtype="")
    )
    # Rank-1 data, axis omitted: the output shape is exactly the index shape.
    _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32"))
    _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype=""))
    _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32"))
    _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype=""))
    _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((i, j, k), "float32"))
    _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo((i, j, k), dtype=""))
def test_take_infer_struct_info_shape_var():
    """When data or index shapes come from shape variables, only dtype and rank are inferred."""
    bb = relax.BlockBuilder()
    sx0 = relax.Var("sx", relax.ShapeStructInfo((4, 10)))
    sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=2))
    sx2 = relax.Var("sx", relax.ShapeStructInfo())
    sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6,)))
    sidx1 = relax.Var("sidx", relax.ShapeStructInfo(ndim=1))
    x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32"))
    x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32"))
    x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32"))
    x3 = relax.Var("x", R.Tensor((4, 10), "float32"))
    idx0 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64"))
    idx1 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64"))
    idx2 = relax.Var("idx", R.Tensor((6,), "int64"))

    every_index = (idx0, idx1, idx2)
    # Rank-2 data behind a shape var: the rank survives, the concrete shape does not.
    for data in (x0, x1):
        for indices in every_index:
            _check_inference(
                bb,
                relax.op.take(data, indices, axis=1),
                relax.TensorStructInfo(dtype="float32", ndim=2),
            )
    # Data with an unknown-rank shape var: only the dtype survives.
    for indices in every_index:
        _check_inference(
            bb, relax.op.take(x2, indices, axis=1), relax.TensorStructInfo(dtype="float32")
        )
    # Static data with shape-var indices: again only dtype and rank.
    for indices in (idx0, idx1):
        _check_inference(
            bb,
            relax.op.take(x3, indices, axis=1),
            relax.TensorStructInfo(dtype="float32", ndim=2),
        )
def test_take_infer_struct_info_more_input_dtype():
    """The output dtype follows the data dtype for int32, int8 and uint32 indices alike."""
    bb = relax.BlockBuilder()
    data_with_dtype = [
        (relax.Var("x", R.Tensor((4, 10), dtype)), dtype) for dtype in ("float16", "int16", "int32")
    ]
    index_vars = [relax.Var("idx", R.Tensor((6,), dtype)) for dtype in ("int32", "int8", "uint32")]
    for indices in index_vars:
        for data, dtype in data_with_dtype:
            _check_inference(
                bb, relax.op.take(data, indices, axis=1), relax.TensorStructInfo((4, 6), dtype)
            )
def test_take_infer_struct_info_indices_not_integer_dtype():
    """Floating-point index tensors must be rejected during normalization."""
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((4, 10), "float32"))
    float_indices = [relax.Var("idx", R.Tensor((6, 6), dt)) for dt in ("float32", "float64")]
    for idx in float_indices:
        with pytest.raises(TVMError):
            bb.normalize(relax.op.take(x, idx, axis=1))
def test_take_infer_struct_info_multi_dimensional_without_axis():
    """Omitting ``axis`` is an error for rank-2 or unknown-rank data."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
    x2 = relax.Var("x", R.Tensor("float32"))
    idx0 = relax.Var("idx", R.Tensor((6,), "int64"))
    idx1 = relax.Var("idx", R.Tensor("int64", ndim=1))
    for indices in (idx0, idx1):
        for data in (x0, x1, x2):
            with pytest.raises(TVMError):
                bb.normalize(relax.op.take(data, indices))
def test_take_infer_struct_info_axis_out_of_range():
    """Axes outside ``[-2, 1]`` for rank-2 data must be rejected."""
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((4, 10), "float32"))
    idx = relax.Var("idx", R.Tensor((6,), "int64"))
    for bad_axis in (-3, 2):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.take(x, idx, axis=bad_axis))
def test_take_infer_struct_info_wrong_input_type():
    """A shape-typed value as either data or indices must be rejected."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", relax.ShapeStructInfo((4, 10)))
    x1 = relax.Var("x", R.Tensor((4, 10), "float32"))
    idx0 = relax.Var("idx", relax.ShapeStructInfo((6,)))
    idx1 = relax.Var("idx", R.Tensor((6,), "int64"))
    for data, indices in ((x0, idx1), (x1, idx0)):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.take(data, indices, axis=1))
def test_strided_slice_infer_struct_info():
    """Static strided slices, including negative strides/axes and vdevice propagation."""
    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
    x2 = relax.Var("x", R.Tensor("float32"))
    x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
    x4 = relax.Var("x", R.Tensor(ndim=4))
    x5 = relax.Var("x", R.Tensor())
    x6 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32", vdev0))

    def slice_axes_0_1_3(data):
        # Slice axes 0, 1 and 3 with strides 2, 1 and -3 (fresh argument lists per call).
        return relax.op.strided_slice(
            data, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]
        )

    expectations = [
        (x0, relax.TensorStructInfo((4, 9, 10, 3), "float32")),
        (x6, relax.TensorStructInfo((4, 9, 10, 3), "float32", vdev0)),
        (x1, relax.TensorStructInfo(dtype="float32", ndim=4)),
        (x2, relax.TensorStructInfo(dtype="float32")),
        (x3, relax.TensorStructInfo((4, 9, 10, 3), dtype="")),
        (x4, relax.TensorStructInfo(dtype="", ndim=4)),
        (x5, relax.TensorStructInfo(dtype="")),
    ]
    for data, expected in expectations:
        _check_inference(bb, slice_axes_0_1_3(data), expected)

    # The same slice written with negative axes in reverse order.
    _check_inference(
        bb,
        relax.op.strided_slice(
            x0, axes=[-1, -3, -4], begin=[8, 0, 1], end=[0, 9, 8], strides=[-3, 1, 2]
        ),
        relax.TensorStructInfo((4, 9, 10, 3), "float32"),
    )
    # No ``strides`` given: the expected extents are end - begin on each sliced axis.
    _check_inference(
        bb,
        relax.op.strided_slice(x0, axes=[1, 2], begin=[1, 0], end=[8, 9]),
        relax.TensorStructInfo((8, 7, 9, 10), "float32"),
    )
def test_strided_slice_infer_struct_info_shape_out_of_range():
    """Begin/end values outside the input extents are clamped when inferring the shape."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((20, 10, 5), "float32"))
    # (begin, end, strides, expected output shape); axes are always [0, 1, 2].
    cases = [
        ([20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2)),
        ([200, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2)),
        ([200, 10, 100], [0, 0, 1], [-1, -3, -5], (19, 3, 1)),
        ([-21, -11, -6], [1, 1, 1], [1000, 1000, 1000], (1, 1, 1)),
    ]
    for begin, end, strides, out_shape in cases:
        _check_inference(
            bb,
            relax.op.strided_slice(x0, axes=[0, 1, 2], begin=begin, end=end, strides=strides),
            relax.TensorStructInfo(out_shape, "float32"),
        )
def test_strided_slice_infer_struct_info_shape_symbolic():
    """Slicing a symbolic ``(m, n)`` input yields ``tir.min``-clamped extents on axis 0."""
    bb = relax.BlockBuilder()
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
    x1 = relax.Var("x", R.Tensor((m, n)))
    for data, dtype in ((x0, "float32"), (x1, "")):
        _check_inference(
            bb,
            relax.op.strided_slice(data, axes=[0], begin=[1], end=[3]),
            relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), dtype=dtype),
        )
        _check_inference(
            bb,
            relax.op.strided_slice(data, axes=[0], begin=[1], end=[8], strides=[3]),
            relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), dtype=dtype),
        )
def test_strided_slice_infer_struct_info_shape_var():
    """Slicing inputs whose shapes come from shape variables, with and without a dtype."""
    bb = relax.BlockBuilder()
    s0 = relax.Var("s", relax.ShapeStructInfo((8, 10)))
    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
    s2 = relax.Var("s", relax.ShapeStructInfo())
    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
    x3 = relax.Var("x", relax.TensorStructInfo(s0, dtype=""))
    x4 = relax.Var("x", relax.TensorStructInfo(s1, dtype=""))
    x5 = relax.Var("x", relax.TensorStructInfo(s2, dtype=""))
    expectations = [
        (x0, relax.TensorStructInfo(shape=[8, 10], dtype="float32")),
        (x1, relax.TensorStructInfo(dtype="float32", ndim=2)),
        (x2, relax.TensorStructInfo(dtype="float32")),
        (x3, relax.TensorStructInfo(shape=[8, 10], dtype="")),
        (x4, relax.TensorStructInfo(dtype="", ndim=2)),
        (x5, relax.TensorStructInfo(dtype="")),
    ]
    for data, expected in expectations:
        _check_inference(bb, relax.op.strided_slice(data, axes=[0], begin=[0], end=[8]), expected)
def test_strided_slice_infer_struct_info_more_input_dtype():
    """The sliced output keeps the input dtype (float16, int32, int64)."""
    bb = relax.BlockBuilder()
    for dtype in ("float16", "int32", "int64"):
        data = relax.Var("x", R.Tensor((8, 9), dtype))
        _check_inference(
            bb,
            relax.op.strided_slice(data, axes=[0], begin=[0], end=[8]),
            relax.TensorStructInfo((8, 9), dtype),
        )
def test_strided_slice_infer_struct_info_symbolic_begin_end_strides():
    """Symbolic begin/end/strides on a static ``(8, 9)`` input, without ``assume_inbound``.

    With the unconstrained ``var`` the expected extents carry negative-index
    ``if_then_else`` guards and max/min clamping; with ``size_var`` (a
    ``tir.SizeVar``) the expected expressions have no such guards.
    """
    bb = relax.BlockBuilder()
    var = tir.Var("var", "int64")
    size_var = tir.SizeVar("size_var", "int64")
    x = relax.Var("x", R.Tensor((8, 9), "float32"))

    def expect_axis0_extent(extent, **slice_args):
        # Only axis 0 is sliced; axis 1 keeps its static extent of 9.
        _check_inference(
            bb,
            relax.op.strided_slice(x, axes=[0], **slice_args),
            relax.TensorStructInfo((extent, 9), dtype="float32"),
        )

    expect_axis0_extent(
        tir.max(8 - tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 0),
        begin=[var],
        end=[8],
    )
    expect_axis0_extent(tir.max(8 - size_var, 0), begin=[size_var], end=[8])
    expect_axis0_extent(
        tir.min(tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 8),
        begin=[0],
        end=[var],
    )
    expect_axis0_extent(tir.min(size_var, 8), begin=[0], end=[size_var])
    expect_axis0_extent(
        tir.if_then_else(var < 0, -8 // (0 - var) + 1, (var + 7) // var),
        begin=[0],
        end=[8],
        strides=[var],
    )
    expect_axis0_extent(7 // size_var + 1, begin=[0], end=[8], strides=[size_var])
def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound():
    """Symbolic begin/end/strides on a static ``(8, 9)`` input with ``assume_inbound=True``.

    Under the in-bounds assumption the expected extents are plain ``end - begin``
    (or a ceiling division of it by a symbolic stride), with no max/min clamping
    and no negative-index ``if_then_else`` guards.
    """
    bb = relax.BlockBuilder()
    var = tir.Var("var", "int64")
    size_var = tir.SizeVar("size_var", "int64")
    x = relax.Var("x", R.Tensor((8, 9), "float32"))
    _check_inference(
        bb,
        relax.op.strided_slice(x, axes=[0], begin=[var], end=[8], assume_inbound=True),
        relax.TensorStructInfo((8 - var, 9), dtype="float32"),
    )
    _check_inference(
        bb,
        relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8], assume_inbound=True),
        relax.TensorStructInfo((8 - size_var, 9), dtype="float32"),
    )
    _check_inference(
        bb,
        relax.op.strided_slice(x, axes=[0], begin=[0], end=[var], assume_inbound=True),
        relax.TensorStructInfo((var, 9), dtype="float32"),
    )
    _check_inference(
        bb,
        relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var], assume_inbound=True),
        relax.TensorStructInfo((size_var, 9), dtype="float32"),
    )
    _check_inference(
        bb,
        relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], assume_inbound=True),
        relax.TensorStructInfo([(var + 7) // var, 9], dtype="float32"),
    )
    # ``size_var`` as the stride; this check previously duplicated the ``var`` case.
    # NOTE(review): the expected form assumes the same simplification the
    # non-inbound test expects for this stride (``7 // size_var + 1``) — confirm.
    _check_inference(
        bb,
        relax.op.strided_slice(
            x, axes=[0], begin=[0], end=[8], strides=[size_var], assume_inbound=True
        ),
        relax.TensorStructInfo([7 // size_var + 1, 9], dtype="float32"),
    )
def test_strided_slice_infer_struct_info_no_axis():
    """Empty ``axes``/``begin``/``end`` keep the input's dtype and shape information.

    For ``x3`` the shape var ``s0`` resolves to its known ``(m, n)`` values.
    """
    bb = relax.BlockBuilder()
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    s0 = relax.Var("s", relax.ShapeStructInfo((m, n)))
    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
    s2 = relax.Var("s", relax.ShapeStructInfo())
    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
    x1 = relax.Var("x", R.Tensor(dtype="float32", ndim=2))
    x2 = relax.Var("x", R.Tensor(dtype="float32"))
    x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
    x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
    x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
    expectations = [
        (x0, relax.TensorStructInfo((m, n), "float32")),
        (x1, relax.TensorStructInfo(dtype="float32", ndim=2)),
        (x2, relax.TensorStructInfo(dtype="float32")),
        (x3, relax.TensorStructInfo([m, n], "float32")),
        (x4, relax.TensorStructInfo(s1, "float32")),
        (x5, relax.TensorStructInfo(s2, "float32")),
    ]
    for data, expected in expectations:
        _check_inference(bb, relax.op.strided_slice(data, axes=[], begin=[], end=[]), expected)
def test_strided_slice_begin_end_strides_int64():
    """Python-int begin/end/strides are stored as int64 values in the call's arguments."""
    x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
    call = relax.op.strided_slice(
        x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]
    )
    # args[1], args[2] and args[3] hold the begin, end and strides values.
    for arg_position in (1, 2, 3):
        values = call.args[arg_position]
        for element in range(3):
            assert values[element].struct_info.dtype == "int64"
def test_strided_slice_inconsistent_axes_begin_end_strides_length():
    """Building the call fails when begin/end/strides lengths disagree with ``axes``."""
    x = relax.Var("x", R.Tensor((8, 9), "float32"))
    mismatched_arguments = [
        dict(axes=[1], begin=[], end=[9]),
        dict(axes=[1], begin=[0], end=[]),
        dict(axes=[1], begin=[0], end=[9], strides=[]),
    ]
    for kwargs in mismatched_arguments:
        with pytest.raises(TVMError):
            relax.op.strided_slice(x, **kwargs)
def test_strided_slice_infer_struct_info_repetitive_axes():
    """The same axis listed twice (directly or via a negative alias) must be rejected."""
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((8, 9), "float32"))
    for duplicated_axes in ([0, 0], [0, -2]):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.strided_slice(x, axes=duplicated_axes, begin=[0, 0], end=[8, 8]))
def test_strided_slice_infer_struct_info_axis_out_of_range():
    """Axes outside ``[-2, 1]`` for a rank-2 input must be rejected."""
    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((8, 9), "float32"))
    for bad_axis in (2, -3):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.strided_slice(x, axes=[bad_axis], begin=[0], end=[8]))
def test_strided_slice_infer_struct_info_wrong_input_type():
    """Shape-typed or function-typed inputs must be rejected."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", relax.ShapeStructInfo((8, 9)))
    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((8, 9), "float32")))
    for non_tensor in (x0, x1):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.strided_slice(non_tensor, axes=[0], begin=[0], end=[8]))
def test_dynamic_strided_slice_infer_struct_info():
    """``dynamic_strided_slice`` keeps the input dtype and rank but never a concrete shape."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
    x2 = relax.Var("x", R.Tensor("float32"))
    x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
    x4 = relax.Var("x", R.Tensor(ndim=4))
    x5 = relax.Var("x", R.Tensor())
    b0 = relax.Var("begin", R.Tensor((4,), "int64"))
    e0 = relax.Var("end", R.Tensor((4,), "int64"))
    s0 = relax.Var("strides", R.Tensor((4,), "int64"))
    b1 = relax.Var("begin", R.Tensor((4,)))
    e1 = relax.Var("end", R.Tensor((4,)))
    s1 = relax.Var("stride", R.Tensor((4,)))

    data_and_expected = [
        (x0, R.Tensor("float32", ndim=4)),
        (x1, R.Tensor("float32", ndim=4)),
        (x2, R.Tensor("float32", ndim=-1)),
        (x3, R.Tensor(ndim=4)),
        (x4, R.Tensor(ndim=4)),
        (x5, R.Tensor(ndim=-1)),
    ]
    # Same expectations whether or not the begin/end/strides tensors carry a dtype.
    for begin, end, strides in ((b0, e0, s0), (b1, e1, s1)):
        for data, expected in data_and_expected:
            _check_inference(
                bb, relax.op.dynamic_strided_slice(data, begin, end, strides), expected
            )
def test_dynamic_strided_slice_infer_struct_info_symbolic():
    """Check dynamic_strided_slice struct-info inference when the data shape is symbolic.

    As in the static case, only the dtype and rank of the data survive inference, and
    this holds for both int64 and dtype-less begin/end/strides tensors.
    """
    bb = relax.BlockBuilder()
    sym_dims = tuple(tir.Var(name, "int64") for name in ("i", "j", "k", "l"))

    inputs = [
        relax.Var("x", R.Tensor(sym_dims, "float32")),
        relax.Var("x", R.Tensor("float32", ndim=4)),
        relax.Var("x", R.Tensor("float32")),
        relax.Var("x", R.Tensor(sym_dims)),
        relax.Var("x", R.Tensor(ndim=4)),
        relax.Var("x", R.Tensor()),
    ]
    # Expected result for each entry of `inputs`, position by position.
    expectations = [
        R.Tensor("float32", ndim=4),
        R.Tensor("float32", ndim=4),
        R.Tensor("float32", ndim=-1),
        R.Tensor(ndim=4),
        R.Tensor(ndim=4),
        R.Tensor(ndim=-1),
    ]

    index_sets = [
        (
            relax.Var("begin", R.Tensor((4,), "int64")),
            relax.Var("end", R.Tensor((4,), "int64")),
            relax.Var("stride", R.Tensor((4,), "int64")),
        ),
        (
            relax.Var("begin", R.Tensor((4,))),
            relax.Var("end", R.Tensor((4,))),
            relax.Var("stride", R.Tensor((4,))),
        ),
    ]

    for indices in index_sets:
        for data, sinfo in zip(inputs, expectations):
            _check_inference(bb, relax.op.dynamic_strided_slice(data, *indices), sinfo)
def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype():
    """dynamic_strided_slice must reject float32 begin/end/strides tensors."""
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
    b0 = relax.Var("begin", R.Tensor((4,), "float32"))
    e0 = relax.Var("end", R.Tensor((4,), "float32"))
    s0 = relax.Var("stride", R.Tensor((4,), "float32"))
    # This must call dynamic_strided_slice. The second positional parameter of
    # strided_slice is `axes`, so passing these tensors to it exercises a different
    # op. It would also not reach dynamic_strided_slice's argument checks at all.
    with pytest.raises(TVMError):
        bb.normalize(relax.op.dynamic_strided_slice(x0, b0, e0, s0))
def test_dynamic_strided_slice_infer_struct_info_arg_wrong_shape_info():
    """dynamic_strided_slice must reject a begin tensor with unusable shape info.

    The data is rank 4, so a well-formed begin is a 1-D tensor with exactly 4
    elements (like the valid end/strides below). Each invalid begin is checked
    separately against otherwise valid arguments.
    """
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
    m = tir.Var("m", "int64")
    # invalid args
    b0 = relax.Var("begin", R.Tensor("int64", ndim=2))  # rank 2 instead of 1-D
    b1 = relax.Var("begin", R.Tensor((1,), "int64"))  # 1 element instead of 4
    b2 = relax.Var("begin", R.Tensor((2, 2), "int64"))  # 4 elements, but rank 2
    b3 = relax.Var("begin", R.Tensor((m,), "int64"))  # length is symbolic, not 4
    # valid args
    e0 = relax.Var("end", R.Tensor((4,), "int64"))
    s0 = relax.Var("stride", R.Tensor((4,), "int64"))
    # This must call dynamic_strided_slice. strided_slice takes `axes` as its second
    # positional parameter, so calling it would exercise a different op's checks.
    for bad_begin in (b0, b1, b2, b3):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.dynamic_strided_slice(x0, bad_begin, e0, s0))
def test_legalize_dynamic_begin_end():
    """relax.op.strided_slice FLegalize must support dynamic begin/end"""

    # Input: slice axis 0 of A from the symbolic `index` to `index + 1`.
    # `index` is introduced by the shape argument B, so begin/end are
    # symbolic TIR expressions rather than integer constants.
    @I.ir_module
    class before:
        @R.function
        def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1, 16)):
            index = T.int64()
            return R.strided_slice(A, [0], [index], [index + 1], assume_inbound=True)

    # Expected legalization: a call_tir into a private PrimFunc that receives
    # `index` as an extra TIR argument (via tir_vars) and writes a (1, 16) output
    # with B[i, j] = A[i + index, j], i.e. row `index` of A.
    # NOTE(review): the output extent is the plain constant 1 here. This presumably
    # relies on assume_inbound=True; confirm against the LegalizeOps implementation.
    @I.ir_module
    class expected:
        @R.function
        def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1, 16)):
            index = T.int64()
            return R.call_tir(
                expected.strided_slice,
                (A,),
                out_sinfo=R.Tensor((1, 16), "float32"),
                tir_vars=R.shape([index]),
            )

        @T.prim_func(private=True)
        def strided_slice(
            A: T.Buffer((T.int64(16), T.int64(16))),
            B: T.Buffer((T.int64(1), T.int64(16))),
            index: T.int64,
        ):
            T.func_attr({"tir.noalias": True})
            for iters in T.grid(*B.shape):
                with T.block("T_dynamic_strided_slice"):
                    i, j = T.axis.remap("SS", iters)
                    B[i, j] = A[i + index, j]

    after = tvm.relax.transform.LegalizeOps()(before)
    tvm.ir.assert_structural_equal(expected, after)
def test_legalize_dynamic_begin_inf_end():
    """relax.op.strided_slice FLegalize must support a dynamic begin with an unbounded end.

    The end is INT64_MAX and ``assume_inbound=False``. The legalized output extent
    along axis 0 is therefore a symbolic expression in ``index`` rather than a constant.
    """

    # Input: slice axis 0 of A from the symbolic `index` up to INT64_MAX
    # ("to the end of the axis").
    @I.ir_module
    class before:
        @R.function
        def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1, 16)):
            index = T.int64()
            return R.strided_slice(
                A, [0], [index], [T.int64(np.iinfo(np.int64).max)], assume_inbound=False
            )

    # Expected extent along axis 0:
    #   begin  = max(index < 0 ? index + 16 : index, 0)   (wrap negatives, clamp at 0)
    #   extent = max(16 - begin, 0)
    # The same expression appears as the PrimFunc's output buffer shape, the call_tir
    # out_sinfo, and main's declared return shape.
    # fmt: off
    @I.ir_module
    class expected:
        @T.prim_func(private=True)
        def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64):
            T.func_attr({"tir.noalias": True})
            T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)))
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)):
                with T.block("T_dynamic_strided_slice_with_axes"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0 + index, v_ax1])
                    T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1])
                    T_dynamic_strided_slice_with_axes[v_ax0, v_ax1] = A[v_ax0 + index, v_ax1]

        @R.function
        def main(A: R.Tensor((16, 16), dtype="float32"), B: R.Shape(["index"])) -> R.Tensor(("T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0)", 16), dtype="float32"):
            index = T.int64()
            cls = expected
            gv = R.call_tir(cls.strided_slice, (A,), out_sinfo=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index]))
            return gv
    # fmt: on

    after = tvm.relax.transform.LegalizeOps()(before)
    tvm.ir.assert_structural_equal(expected, after)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()