# blob: 660b5d27720bcdc91188eaf4b2815fe175b158a9 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import TVMError, relax, tir
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
    """Normalize ``call`` with ``bb`` and verify the resulting struct info.

    Parameters
    ----------
    bb : relax.BlockBuilder
        Builder whose ``normalize`` method is applied to ``call``.
    call : relax.Call
        The call expression under test.
    expected_sinfo : relax.StructInfo
        Struct info that the normalized call must be structurally equal to.
    """
    tvm.ir.assert_structural_equal(bb.normalize(call).struct_info, expected_sinfo)
def test_all_class_non_max_suppression_infer_struct_info():
    """Check struct-info inference of all-class NMS for static input shapes.

    The expected result is a tuple of an int64 tensor of shape
    ``(batch_size * num_classes * num_boxes, 3)`` and an int64 tensor of
    shape ``(1,)``.
    """
    batch_size, num_classes, num_boxes = 10, 8, 5
    boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32"))
    scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32"))

    nms_call = relax.op.vision.all_class_non_max_suppression(
        boxes,
        scores,
        relax.const(10, "int64"),  # max_output_boxes_per_class
        relax.const(0.5, "float32"),  # iou_threshold
        relax.const(0.1, "float32"),  # score_threshold
        "onnx",
    )
    expected_sinfo = relax.TupleStructInfo(
        [
            relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"),
            relax.TensorStructInfo((1,), "int64"),
        ]
    )
    _check_inference(relax.BlockBuilder(), nms_call, expected_sinfo)
def test_all_class_non_max_suppression_wrong_input_number():
    """Calling the op with only ``boxes`` and ``scores`` must raise ``TVMError``."""
    boxes_sinfo = R.Tensor((1, 5, 4), "float32")
    scores_sinfo = R.Tensor((1, 3, 5), "float32")
    boxes, scores = relax.Var("boxes", boxes_sinfo), relax.Var("scores", scores_sinfo)

    with pytest.raises(TVMError):
        relax.op.vision.all_class_non_max_suppression(boxes, scores)
def test_all_class_non_max_suppression_infer_struct_info_shape_var():
    """Check struct-info inference of all-class NMS for symbolic input shapes.

    Batch size, class count and box count are int64 ``tir.Var`` symbols; the
    expected first output has the symbolic leading dimension
    ``batch_size * num_classes * num_boxes``.
    """
    batch_size, num_classes, num_boxes = (
        tir.Var(dim_name, "int64") for dim_name in ("batch_size", "num_classes", "num_boxes")
    )
    boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32"))
    scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32"))

    nms_call = relax.op.vision.all_class_non_max_suppression(
        boxes,
        scores,
        relax.const(10, "int64"),  # max_output_boxes_per_class
        relax.const(0.5, "float32"),  # iou_threshold
        relax.const(0.1, "float32"),  # score_threshold
        "onnx",
    )
    expected_sinfo = relax.TupleStructInfo(
        [
            relax.TensorStructInfo((batch_size * num_classes * num_boxes, 3), "int64"),
            relax.TensorStructInfo((1,), "int64"),
        ]
    )
    _check_inference(relax.BlockBuilder(), nms_call, expected_sinfo)
def test_all_class_non_max_suppression_legalize_dynamic_trim():
    """Check that legalization yields a dynamically shaped first output.

    After ``LegalizeOps`` the printed module must mention
    ``dynamic_strided_slice`` and ``main`` must return a tuple of a rank-2
    int64 tensor (shape unknown) and an int64 tensor of shape ``(1,)``.
    """

    @tvm.script.ir_module
    class NMSModule:
        @R.function
        def main(
            boxes: R.Tensor((1, 5, 4), "float32"),
            scores: R.Tensor((1, 2, 5), "float32"),
        ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
            max_output_boxes_per_class = R.const(3, "int64")
            iou_threshold = R.const(0.5, "float32")
            score_threshold = R.const(0.1, "float32")
            return R.vision.all_class_non_max_suppression(
                boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
            )

    legalized = LegalizeOps()(NMSModule)

    # The trimmed output is produced through dynamic_strided_slice, so the op
    # name has to show up in the printed module.
    printed = str(legalized)
    assert "dynamic_strided_slice" in printed

    expected_ret = relax.TupleStructInfo(
        [
            relax.TensorStructInfo(ndim=2, dtype="int64"),
            relax.TensorStructInfo((1,), "int64"),
        ]
    )
    tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, expected_ret)
def test_all_class_non_max_suppression_legalize_e2e():
    """Legalize, compile and execute all-class NMS end to end on the CPU.

    The test first checks the return struct info of the legalized ``main``
    (rank-2 int64 tensor plus an int64 tensor of shape ``(1,)``), then builds
    the module for the ``llvm`` target, runs it in the Relax virtual machine
    and verifies that the first output has exactly as many rows as the
    detection count reported by the second output, with 3 columns.
    """

    @tvm.script.ir_module
    class NMSModule:
        @R.function
        def main(
            boxes: R.Tensor((1, 5, 4), "float32"),
            scores: R.Tensor((1, 2, 5), "float32"),
        ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
            max_output_boxes_per_class = R.const(3, "int64")
            iou_threshold = R.const(0.5, "float32")
            score_threshold = R.const(0.1, "float32")
            return R.vision.all_class_non_max_suppression(
                boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
            )

    # One batch of five boxes; the first two boxes overlap heavily.
    boxes_data = np.array(
        [
            [
                [0.0, 0.0, 1.0, 1.0],
                [0.1, 0.1, 1.1, 1.1],
                [2.0, 2.0, 3.0, 3.0],
                [4.0, 4.0, 5.0, 5.0],
                [6.0, 6.0, 7.0, 7.0],
            ]
        ],
        dtype=np.float32,
    )
    # One batch, two classes, one score per box.
    scores_data = np.array(
        [[[0.9, 0.8, 0.7, 0.6, 0.5], [0.85, 0.75, 0.65, 0.55, 0.45]]],
        dtype=np.float32,
    )

    mod = LegalizeOps()(NMSModule)

    # Check struct info
    tvm.ir.assert_structural_equal(
        mod["main"].ret_struct_info,
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo(ndim=2, dtype="int64"),
                relax.TensorStructInfo((1,), "int64"),
            ]
        ),
    )

    # Check runtime execution
    exe = tvm.compile(mod, target="llvm")
    vm = relax.VirtualMachine(exe, tvm.cpu())
    result = vm["main"](
        tvm.runtime.tensor(boxes_data, tvm.cpu()),
        tvm.runtime.tensor(scores_data, tvm.cpu()),
    )
    selected_indices = result[0].numpy()
    num_total_detections = int(result[1].numpy()[0])

    # Shapes are exact integers: compare them exactly instead of going through
    # the tolerance-based assert_allclose helper.
    assert selected_indices.shape == (num_total_detections, 3)
# When this file is executed directly as a script, hand control to
# tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()