# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
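"""Tests for tuning models through TVMC with the auto-scheduler enabled."""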
import os
import platform
from os import path

import pytest

from tvm import auto_scheduler
from tvm.driver import tvmc


def _get_tasks(model):
    """Load a model through the TVMC frontends and extract its auto-scheduler tuning tasks."""
    tvmc_model = tvmc.frontends.load_model(model)
    tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(
        tvmc_model.mod, tvmc_model.params, "llvm"
    )
    return (tasks, weights)


def _autoscheduler_test_helper(model, tmpdir_name, early_stopping=1, prior_records=None):
    """Run a short auto-scheduler tuning session and verify that usable records are produced."""
    tvmc_model = tvmc.frontends.load_model(model)
    log_file = os.path.join(tmpdir_name, "autoscheduler.json")

    hardware_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm")

    tvmc.tune(
        tvmc_model,
        target="llvm",
        tuning_records=log_file,
        prior_records=prior_records,
        early_stopping=early_stopping,
        enable_autoscheduler=True,
        trials=2,
        hardware_params=hardware_params,
    )

    # Check that the tuning log file was produced and that it can be loaded back.
    assert path.exists(log_file), "autoscheduler log file should exist"

    with auto_scheduler.ApplyHistoryBest(log_file) as best:
        assert isinstance(
            best, auto_scheduler.dispatcher.ApplyHistoryBest
        ), "unable to load the best results of tuning"

    return log_file


def test_get_tuning_tasks(keras_simple):
    """Smoke-test extraction of auto-scheduler tuning tasks from a simple Keras model."""
    pytest.importorskip("tensorflow")

    tasks, weights = _get_tasks(keras_simple)
    expected_task_type = auto_scheduler.SearchTask

    assert type(tasks) is list
    assert len(tasks) > 0
    assert all(type(x) is expected_task_type for x in tasks)


@pytest.mark.skipif(
    platform.machine() == "aarch64",
    reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673",
)
def test_tune_tasks(keras_simple, tmpdir_factory):
    """Tune a simple Keras model end to end with the auto-scheduler enabled."""
    pytest.importorskip("tensorflow")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(keras_simple, tmpdir_name)


@pytest.mark.skipif(
    platform.machine() == "aarch64",
    reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673",
)
def test_tune_tasks__tuning_records(keras_simple, tmpdir_factory):
    """Tune twice, feeding the first log back in as prior records (transfer learning)."""
    pytest.importorskip("tensorflow")

    tmpdir_name = tmpdir_factory.mktemp("data")
    output_log_phase_1 = _autoscheduler_test_helper(keras_simple, tmpdir_name)

    # Exercises transfer learning by making sure a previous log exists
    _autoscheduler_test_helper(keras_simple, tmpdir_name, prior_records=output_log_phase_1)


@pytest.mark.skipif(
    platform.machine() == "aarch64",
    reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673",
)
def test_tune_tasks__no_early_stopping(keras_simple, tmpdir_factory):
    """Tune with early stopping disabled (early_stopping=None)."""
    pytest.importorskip("tensorflow")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(keras_simple, tmpdir_name, early_stopping=None)
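

# The tests above only verify that tuning produces a loadable record file. As a
# minimal, hedged sketch (not exercised by this test suite), the records written
# by ``tvmc.tune`` would typically be fed back into compilation; this assumes the
# TVMC Python shortcuts ``tvmc.compile`` and ``tvmc.run`` are available in this
# build of TVM.
def _example_consume_tuning_records(tvmc_model, log_file):
    """Sketch only: compile and run a model using previously produced tuning records."""
    # ``tuning_records`` points compilation at the schedules found during tuning.
    package = tvmc.compile(tvmc_model, target="llvm", tuning_records=log_file)
    # Run the compiled package on the local CPU to confirm it is usable.
    return tvmc.run(package, device="cpu")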