# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import platform
import pytest
import os
from unittest import mock
from os import path
from pathlib import Path

import tvm
import tvm.testing
from tvm import autotvm, auto_scheduler
from tvm.driver import tvmc
from tvm.driver.tvmc.autotuner import filter_tasks, gen_task_list


def _get_tasks(model):
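    """Load a model via the TVMC frontends and return its AutoTVM tuning tasks for the llvm target."""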
tvmc_model = tvmc.frontends.load_model(model)
return tvmc.autotuner.autotvm_get_tuning_tasks(tvmc_model.mod, tvmc_model.params, "llvm")


def _get_measure_options():
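    """Build a minimal AutoTVM measure_option: the default local builder and a local runner."""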
return autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func="default"), runner="local"
)


def _tuner_test_helper(model, tuner_name, tmpdir_name, early_stopping=1, prior_records=None):
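    """Tune the given model with the requested tuner, check that a usable tuning log was produced, and return its path."""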
tvmc_model = tvmc.frontends.load_model(model)
log_file = os.path.join(tmpdir_name, "log_{}.txt".format(tuner_name))
tvmc.tune(
tvmc_model,
target="llvm",
tuning_records=log_file,
prior_records=prior_records,
tuner=tuner_name,
trials=4,
early_stopping=early_stopping,
)
    # check that the tuning log file was produced
assert path.exists(log_file), "tuning log file should exist"
with autotvm.apply_history_best(log_file) as best:
assert isinstance(
best, autotvm.task.dispatcher.ApplyHistoryBest
), "unable to load the best results of tuning"
return log_file


def test_get_tuning_tasks(onnx_mnist):
pytest.importorskip("onnx")
sut = _get_tasks(onnx_mnist)
expected_task_type = autotvm.task.Task
assert type(sut) is list
assert len(sut) > 0
    assert all(type(x) is expected_task_type for x in sut)


@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673",
)
def test_tune_tasks__tuner__xgb(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "xgb", tmpdir_name)


def test_tune_tasks__tuner__xgb_knob(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "xgb_knob", tmpdir_name)


def test_tune_tasks__tuner__ga(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "ga", tmpdir_name)


def test_tune_tasks__tuner__random(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "random", tmpdir_name)


def test_tune_tasks__tuner__gridsearch(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)


def test_tune_tasks__tuner__gridsearch__tuning_records(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
output_log_phase_1 = _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)
# Exercises transfer learning by making sure a previous log exists
_tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name, prior_records=output_log_phase_1)


def test_tune_tasks__tuner__ga__empty_tasks(tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
    log_file = os.path.join(tmpdir_name, "log_ga.txt")
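    # No tasks are provided, so tuning should be a no-op that completes without raising.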
tvmc.autotuner.tune_tasks(
tasks=[],
log_file=log_file,
measure_option=_get_measure_options(),
tuner="ga",
trials=1,
early_stopping=1,
)


def test_tune_tasks__tuner__xgb__no_early_stopping(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "xgb", tmpdir_name, early_stopping=None)


@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673",
)
def test_tune_tasks__tuner__xgb__no_tuning_records(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "xgb", tmpdir_name, prior_records=None)


def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory):
pytest.importorskip("onnx")
tasks = _get_tasks(onnx_mnist)
log_file = os.path.join(tmpdir_factory.mktemp("data"), "log2.txt")
with pytest.raises(tvmc.TVMCException):
tvmc.autotuner.tune_tasks(tasks, log_file, _get_measure_options(), "invalid_tuner", 1, 1)
@mock.patch("tvm.driver.tvmc.autotuner.auto_scheduler.HardwareParams", return_value=None)
@mock.patch("tvm.driver.tvmc.autotuner.tune_model", return_value=None)
@mock.patch("tvm.driver.tvmc.frontends.load_model", return_value=None)
def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_scheduler):
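    """Check that an RPC tracker given as "host:port" is split into hostname/port arguments for tune_model."""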
cli_args = mock.MagicMock()
cli_args.rpc_tracker = "10.0.0.1:9999"
    # FILE is not used, but set it to a valid path so that mock does not
    # replace it with a MagicMock instance, which would fail the checks
    # for a valid FILE.
fake_input_file = "./fake_input_file.tflite"
Path(fake_input_file).touch()
cli_args.FILE = fake_input_file
tvmc.autotuner.drive_tune(cli_args)
os.remove(fake_input_file)
mock_tune_model.assert_called_once()
    # inspect the mock call to check that the RPC tracker address was parsed
    _, _, kwargs = mock_tune_model.mock_calls[0]
    assert "hostname" in kwargs
    assert kwargs["hostname"] == "10.0.0.1"
    assert "port" in kwargs
    assert kwargs["port"] == 9999
@mock.patch("tvm.transform.PassContext", return_value=tvm.transform.PassContext())
def test_autotune_pass_context(mock_pc, onnx_mnist, tmpdir_factory):
"""
Check that the pass context while tuning is as expected.
"""
pytest.importorskip("onnx")
tmpdir_name = tmpdir_factory.mktemp("data")
_tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)
# AutoTVM overrides the pass context later in the pipeline to disable AlterOpLayout
assert mock_pc.call_count == 2
assert mock_pc.call_args_list[0][1]["opt_level"] == 3


def test_filter_tasks_valid():
    assert filter_tasks(list(range(10)), "list") == ([], True)
    assert filter_tasks(list(range(10)), "help") == ([], True)
    assert filter_tasks(list(range(10)), "all") == ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], False)
    assert filter_tasks(list(range(10)), "5") == ([5], False)
    assert filter_tasks(list(range(10)), "1-5") == ([1, 2, 3, 4, 5], False)
    assert filter_tasks(list(range(10)), "-5") == ([0, 1, 2, 3, 4, 5], False)
    assert filter_tasks(list(range(10)), "6-") == ([6, 7, 8, 9], False)
    assert filter_tasks(list(range(10)), "0,1-3,all") == ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], False)
    assert filter_tasks(list(range(10)), "0,4-5,9,list") == ([0, 4, 5, 9], True)


@pytest.mark.parametrize(
"value,err_msg",
[
("10", "Task index out of range"),
("5,10", "Task index out of range"),
("1-10", "Right-hand side expression out of range"),
("-10", "Right-hand side expression out of range"),
("-", "Missing lhs or rhs for range expression"),
("-10-", "Malformed range expression"),
("--", "Malformed range expression"),
],
)
def test_filter_tasks_invalid(value, err_msg):
with pytest.raises(AssertionError, match=err_msg):
filter_tasks(list(range(10)), value)


@pytest.mark.parametrize(
"enable_autoscheduler,expected",
[
(
False,
"""Available Tasks for tuning:
0. Task(func_name=taskA, args=[], kwargs={}, workload=('taskA',)) (len=?)
1. Task(func_name=taskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBta... (len=?)
2. Task(func_name=taskC, args=[], kwargs={}, workload=('taskC',)) (len=?)""",
),
(
True,
"""Available Tasks for tuning:
0. taskA
1. taskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBta...
2. Unnamed""",
),
],
)
def test_print_task_list(enable_autoscheduler, expected):
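    """Check the generated task list for both AutoTVM and auto-scheduler tasks,
    including truncation of overly long task descriptions."""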
if enable_autoscheduler:
auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear()
N = 64
target = "llvm"
test_input_0 = tvm.runtime.ndarray.empty((64, 64))
test_input_1 = tvm.runtime.ndarray.empty((10, 20))
test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50))
task_inputs = {
"test_input_0": test_input_0,
"test_input_1": test_input_1,
"test_input_2": test_input_2,
}
task1 = auto_scheduler.SearchTask(
func="matmul_auto_scheduler_test",
args=(N, N, N),
target=target,
task_inputs=task_inputs,
task_inputs_overwrite=True,
desc="taskA",
)
task2 = auto_scheduler.SearchTask(
func="matmul_auto_scheduler_test",
args=(N, N, N),
target=target,
task_inputs=task_inputs,
task_inputs_overwrite=True,
desc="taskB" * 20, # very long name
)
task3 = auto_scheduler.SearchTask(
func="matmul_auto_scheduler_test",
args=(N, N, N),
target=target,
task_inputs=task_inputs,
task_inputs_overwrite=True,
# missing description
)
else:
task1 = autotvm.task.Task("taskA", [])
task2 = autotvm.task.Task("taskB" * 20, []) # very long name
task3 = autotvm.task.Task("taskC", [])
tasks = [task1, task2, task3]
out = gen_task_list(tasks, enable_autoscheduler)
assert out == expected


if __name__ == "__main__":
tvm.testing.main()