| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """The core tuning API""" |
| from typing import List, Optional |
| |
| from .builder import Builder |
| from .cost_model import CostModel |
| from .database import Database |
| from .measure_callback import MeasureCallback |
| from .runner import Runner |
| from .task_scheduler import TaskScheduler |
| from .tune_context import TuneContext |
| |
| |
def tune_tasks(
    *,
    tasks: List[TuneContext],
    task_weights: List[float],
    work_dir: str,
    max_trials_global: int,
    max_trials_per_task: Optional[int] = None,
    num_trials_per_iter: int = 64,
    builder: Builder.BuilderType = "local",
    runner: Runner.RunnerType = "local",
    database: Database.DatabaseType = "json",
    cost_model: CostModel.CostModelType = "xgb",
    measure_callbacks: MeasureCallback.CallbackListType = "default",
    task_scheduler: TaskScheduler.TaskSchedulerType = "gradient",
    module_equality: str = "structural",
) -> Database:
    """Tune a list of tasks using a task scheduler.

    Every component argument (builder, runner, database, cost model, task
    scheduler) may be given either as an already-constructed object or as a
    specification that is forwarded to the matching ``create`` factory.

    Parameters
    ----------
    tasks : List[TuneContext]
        The list of tasks to tune. Must not be empty.
    task_weights : List[float]
        The weight of each task. Must have the same length as `tasks`.
    work_dir : str
        The working directory. It is forwarded to the database factory when
        `database` is "json".
    max_trials_global : int
        The maximum number of trials to run globally.
    max_trials_per_task : Optional[int]
        The maximum number of trials to run per task. When None, it falls back
        to `max_trials_global`.
    num_trials_per_iter : int
        The number of trials to run per iteration. Defaults to 64.
    builder : Builder.BuilderType
        The builder. Defaults to "local".
    runner : Runner.RunnerType
        The runner. Defaults to "local".
    database : Database.DatabaseType
        The database. Defaults to "json".
    cost_model : CostModel.CostModelType
        The cost model. Defaults to "xgb".
    measure_callbacks : MeasureCallback.CallbackListType
        The measure callbacks. A single `MeasureCallback` is wrapped into a
        one-element list, "default" is resolved through
        `MeasureCallback.create`, and any other value is passed to the
        scheduler unchanged.
    task_scheduler : TaskScheduler.TaskSchedulerType
        The task scheduler. Defaults to "gradient".
    module_equality : str
        A string to specify the module equality testing and hashing method.
        It is only used when the database is created by this function.
        It must be one of the following:

        - "structural": Use StructuralEqual/Hash
        - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during equality
          testing and hashing.
        - "anchor-block": Apply equality testing and hashing on the anchor block extracted from
          a given module. The "ignore-ndarray" variant is used for the extracted blocks or in
          case no anchor block is found. For the definition of the anchor block, see
          tir/analysis/analysis.py.

    Returns
    -------
    database : Database
        The database with all tuning records

    Raises
    ------
    ValueError
        If `tasks` is empty, or if `tasks` and `task_weights` differ in length.
    """
    if not tasks:
        raise ValueError("No tasks to tune.")

    if len(tasks) != len(task_weights):
        raise ValueError(
            f"Length of tasks ({len(tasks)}) and task_weights ({len(task_weights)}) do not match."
        )

    # The thread count of the first task sizes the builder/runner worker pools
    # and the cost model's tuning cores.
    num_cores = tasks[0].num_threads

    if max_trials_per_task is None:
        max_trials_per_task = max_trials_global

    if not isinstance(builder, Builder):
        builder = Builder.create(builder, max_workers=num_cores)
    if not isinstance(runner, Runner):
        runner = Runner.create(runner, max_workers=num_cores)

    # Only the "json" database is given the working directory; every other
    # specification is created from `module_equality` alone.
    if database == "json":
        database = Database.create(database, work_dir=work_dir, module_equality=module_equality)
    elif not isinstance(database, Database):
        database = Database.create(database, module_equality=module_equality)

    if not isinstance(cost_model, CostModel):
        cost_model = CostModel.create(cost_model, num_tuning_cores=num_cores, tree_method="auto")

    # Normalize the callbacks: a lone callback becomes a list, and "default"
    # is resolved by the factory.
    if isinstance(measure_callbacks, MeasureCallback):
        measure_callbacks = [measure_callbacks]
    elif measure_callbacks == "default":
        measure_callbacks = MeasureCallback.create(measure_callbacks)

    if not isinstance(task_scheduler, TaskScheduler):
        task_scheduler = TaskScheduler.create(task_scheduler)

    task_scheduler.tune(
        tasks=tasks,
        task_weights=task_weights,
        max_trials_global=max_trials_global,
        max_trials_per_task=max_trials_per_task,
        num_trials_per_iter=num_trials_per_iter,
        builder=builder,
        runner=runner,
        measure_callbacks=measure_callbacks,
        database=database,
        cost_model=cost_model,
    )
    return database