blob: 59315b605df145a0790762a37b45da0179cfea76 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: E722
"""Test Tools in MSC."""
import json
import pytest
import torch
import tvm.testing
from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.core.tools import ToolType
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.pipeline import MSCManager
# Skip marker for tests that need the relax TensorRT BYOC runtime function.
_trt_enabled = tvm.get_global_func("relax.ext.tensorrt", True) is not None
requires_tensorrt = pytest.mark.skipif(not _trt_enabled, reason="TENSORRT is not enabled")
def _get_config(
    model_type,
    compile_type,
    tools,
    inputs,
    outputs,
    atol=1e-2,
    rtol=1e-2,
    optimize_type=None,
):
    """Build the MSCManager config dict for a tools test run.

    The workspace name encodes the model/compile types and the tool types so
    concurrent test cases get distinct directories.
    """

    def _profile():
        # Each stage gets its own fresh profile dict (check + benchmark).
        return {"check": {"atol": atol, "rtol": rtol}, "benchmark": {"repeat": 10}}

    tool_names = [t["tool_type"] for t in tools]
    workspace = "_".join(["test_tools", model_type, compile_type] + tool_names)
    return {
        "workspace": msc_utils.msc_dir(workspace, keep_history=False),
        "verbose": "critical",
        "model_type": model_type,
        "inputs": inputs,
        "outputs": outputs,
        "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
        "tools": tools,
        "prepare": {"profile": {"benchmark": {"repeat": 10}}},
        "baseline": {"run_type": model_type, "profile": _profile()},
        # Fall back to the source model type when no optimize target is given.
        "optimize": {"run_type": optimize_type or model_type, "profile": _profile()},
        "compile": {"run_type": compile_type, "profile": _profile()},
    }
def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC):
    """Build the list of tool configs for one tool type.

    Parameters
    ----------
    tool_type: str
        One of the ToolType values (PRUNER/QUANTIZER/TRACKER).
    use_distill: bool
        Whether to append a distiller tool config.
    run_type: str
        The framework the tool runs on; affects the quantizer strategys.
    """

    tools = []
    if tool_type == ToolType.PRUNER:
        prune_config = {
            "plan_file": "msc_pruner.json",
            "strategys": [
                {
                    "methods": {
                        "weights": {"method_name": "per_channel", "density": 0.8},
                        "output": {"method_name": "per_channel", "density": 0.8},
                    }
                }
            ],
        }
        tools.append({"tool_type": ToolType.PRUNER, "tool_config": prune_config})
    elif tool_type == ToolType.QUANTIZER:
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.msc.core.tools.quantize import QuantizeStage

        if run_type == MSCFramework.TENSORRT:
            # No explicit strategys when running the quantizer on TensorRT.
            quant_config = {"plan_file": "msc_quantizer.json", "strategys": []}
        else:
            op_types = ["nn.conv2d", "msc.conv2d_bias", "msc.linear", "msc.linear_bias"]
            gather_stage = {
                "methods": {
                    "input": "gather_maxmin",
                    "output": "gather_maxmin",
                    "weights": "gather_max_per_channel",
                },
                "op_types": op_types,
                "stages": [QuantizeStage.GATHER],
            }
            calibrate_stage = {
                "methods": {"input": "calibrate_maxmin", "output": "calibrate_maxmin"},
                "op_types": op_types,
                "stages": [QuantizeStage.CALIBRATE],
            }
            quantize_stage = {
                "methods": {
                    "input": "quantize_normal",
                    "weights": "quantize_normal",
                    "output": "dequantize_normal",
                },
                "op_types": op_types,
            }
            quant_config = {
                "plan_file": "msc_quantizer.json",
                "strategys": [gather_stage, calibrate_stage, quantize_stage],
            }
        tools.append({"tool_type": ToolType.QUANTIZER, "tool_config": quant_config})
    elif tool_type == ToolType.TRACKER:
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.msc.core.utils import MSCStage

        track_config = {
            "plan_file": "msc_tracker.json",
            "strategys": [
                {
                    "methods": {
                        "output": {
                            "method_name": "save_compared",
                            "compare_to": {
                                MSCStage.OPTIMIZE: [MSCStage.BASELINE],
                                MSCStage.COMPILE: [MSCStage.OPTIMIZE, MSCStage.BASELINE],
                            },
                        }
                    },
                    "op_types": ["nn.relu"],
                }
            ],
        }
        tools.append({"tool_type": ToolType.TRACKER, "tool_config": track_config})
    if use_distill:
        distill_config = {
            "plan_file": "msc_distiller.json",
            "strategys": [
                {
                    "methods": {"mark": "loss_lp_norm"},
                    "marks": ["loss"],
                },
            ],
        }
        tools.append({"tool_type": ToolType.DISTILLER, "tool_config": distill_config})
    return tools
def _get_torch_model(name, training=False):
    """Get a model from torchvision, or None when torchvision is unavailable.

    Parameters
    ----------
    name: str
        Attribute name of the model builder in torchvision.models.
    training: bool
        Whether to put the model in train mode (default eval mode).

    Returns
    -------
    The torch model, or None if torchvision cannot be imported.
    """

    try:
        # pylint: disable=import-outside-toplevel
        import torchvision
    except ImportError:
        # Only a missing torchvision is tolerated; any other failure (such as
        # an unknown model name) is a real error and should propagate instead
        # of being hidden behind a misleading "please install" message.
        print("please install torchvision package")
        return None
    model = getattr(torchvision.models, name)()
    return model.train() if training else model.eval()
def _check_manager(manager, expected_info):
    """Check the manager results against the expected model info.

    Raises an Exception carrying every detected failure plus the full manager
    report when the pipeline failed or the model info does not match.
    """

    model_info = manager.get_runtime().model_info
    # Collect all failures instead of keeping only the last one, so a pipe
    # failure is not hidden by a subsequent model-info mismatch.
    errors = []
    if not manager.report["success"]:
        errors.append(f"Failed to run pipe for {manager.model_type} -> {manager.compile_type}")
    if not msc_utils.dict_equal(model_info, expected_info):
        errors.append(f"Model info {model_info} mismatch with expected {expected_info}")
    # NOTE(review): `destory` (sic) is the actual MSCManager method name; keep as-is.
    manager.destory()
    if errors:
        err = "\n".join(errors)
        raise Exception(f"{err}\nReport:{json.dumps(manager.report, indent=2)}")
def _test_from_torch(
    compile_type,
    tools,
    expected_info,
    training=False,
    atol=1e-1,
    rtol=1e-1,
    optimize_type=None,
):
    """Run the MSC pipeline on torchvision resnet50 and verify the results."""

    torch_model = _get_torch_model("resnet50", training)
    if not torch_model:
        # torchvision is unavailable -> nothing to run.
        return
    if torch.cuda.is_available():
        torch_model = torch_model.to(torch.device("cuda:0"))
    config = _get_config(
        MSCFramework.TORCH,
        compile_type,
        tools,
        inputs=[["input_0", [1, 3, 224, 224], "float32"]],
        outputs=["output"],
        atol=atol,
        rtol=rtol,
        optimize_type=optimize_type,
    )
    manager = MSCManager(torch_model, config)
    manager.run_pipe()
    _check_manager(manager, expected_info)
def get_model_info(compile_type):
    """Get the expected compiled-model info of resnet50 for a target.

    Raises TypeError for any compile_type other than TVM or TENSORRT.
    """

    # Both targets consume the same input tensor spec.
    input_infos = [
        {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
    ]
    if compile_type == MSCFramework.TVM:
        return {
            "inputs": input_infos,
            "outputs": [
                {"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}
            ],
            "nodes": {
                "total": 229,
                "input": 1,
                "nn.conv2d": 53,
                "nn.batch_norm": 53,
                "get_item": 53,
                "nn.relu": 49,
                "nn.max_pool2d": 1,
                "add": 16,
                "nn.adaptive_avg_pool2d": 1,
                "reshape": 1,
                "msc.linear_bias": 1,
            },
        }
    if compile_type == MSCFramework.TENSORRT:
        # The whole graph is offloaded to a single msc_tensorrt node.
        return {
            "inputs": input_infos,
            "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}],
            "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
        }
    raise TypeError("Unexpected compile_type " + str(compile_type))
@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER])
def test_tvm_tool(tool_type):
    """Test tools for tvm"""

    expected = get_model_info(MSCFramework.TVM)
    _test_from_torch(MSCFramework.TVM, get_tools(tool_type), expected, training=False)
@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER])
def test_tvm_distill(tool_type):
    """Test tools for tvm with distiller"""

    expected = get_model_info(MSCFramework.TVM)
    _test_from_torch(MSCFramework.TVM, get_tools(tool_type, use_distill=True), expected, training=False)
@requires_tensorrt
@pytest.mark.parametrize(
    "tool_type",
    [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER],
)
def test_tensorrt_tool(tool_type):
    """Test tools for tensorrt"""

    tools = get_tools(tool_type, run_type=MSCFramework.TENSORRT)
    # Only the quantizer runs the optimize stage on TensorRT.
    optimize_type = MSCFramework.TENSORRT if tool_type == ToolType.QUANTIZER else None
    _test_from_torch(
        MSCFramework.TENSORRT,
        tools,
        get_model_info(MSCFramework.TENSORRT),
        training=False,
        atol=1e-1,
        rtol=1e-1,
        optimize_type=optimize_type,
    )
@requires_tensorrt
@pytest.mark.parametrize("tool_type", [ToolType.PRUNER])
def test_tensorrt_distill(tool_type):
    """Test tools for tensorrt with distiller"""

    tools = get_tools(tool_type, use_distill=True)
    expected = get_model_info(MSCFramework.TENSORRT)
    _test_from_torch(MSCFramework.TENSORRT, tools, expected, training=False)
# Allow running this test file directly (outside of a pytest invocation).
if __name__ == "__main__":
    tvm.testing.main()