| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring |
| """Test Meta Schedule Database""" |
| import os.path as osp |
| import tempfile |
| from typing import Callable, List, Optional |
| |
| import pytest |
| import tvm |
| import tvm.testing |
| from tvm import meta_schedule as ms |
| from tvm import tir |
| from tvm.ir.module import IRModule |
| from tvm.meta_schedule.database import TuningRecord, Workload |
| from tvm.script import tir as T |
| from tvm.target import Target |
| from tvm.tir import Schedule |
| |
| |
| # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument |
| # fmt: off |
| @tvm.script.ir_module |
| class Matmul: |
| @T.prim_func |
| def main(a: T.handle, b: T.handle, c: T.handle) -> None: |
| T.func_attr({"global_symbol": "main"}) |
| A = T.match_buffer(a, (1024, 1024), "float32") |
| B = T.match_buffer(b, (1024, 1024), "float32") |
| C = T.match_buffer(c, (1024, 1024), "float32") |
| for i, j, k in T.grid(1024, 1024, 1024): |
| with T.block("matmul"): |
| vi, vj, vk = T.axis.remap("SSR", [i, j, k]) |
| with T.init(): |
| C[vi, vj] = 0.0 |
| C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] |
| |
| |
| @tvm.script.ir_module |
| class MatmulRelu: |
| @T.prim_func |
| def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-self-argument |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| A = T.match_buffer(a, (16, 16), "float32") |
| B = T.match_buffer(b, (16, 16), "float32") |
| D = T.match_buffer(d, (16, 16), "float32") |
| C = T.alloc_buffer((16, 16), "float32") |
| for i, j, k in T.grid(16, 16, 16): |
| with T.block("matmul"): |
| vi, vj, vk = T.axis.remap("SSR", [i, j, k]) |
| with T.init(): |
| C[vi, vj] = 0.0 |
| C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] |
| for i, j in T.grid(16, 16): |
| with T.block("relu"): |
| vi, vj = T.axis.remap("SS", [i, j]) |
| D[vi, vj] = T.max(C[vi, vj], 0.0) |
| |
| |
| # fmt: on |
| # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument |
| |
| |
| def _schedule_matmul(sch: Schedule): |
| block = sch.get_block("matmul") |
| i, j, k = sch.get_loops(block=block) |
| i_tiles = [1, 1, 2, 512] |
| j_tiles = [1, 512, 1, 2] |
| k_tiles = [256, 4] |
| i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=i_tiles) |
| j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=j_tiles) |
| k_0, k_1 = sch.split(loop=k, factors=k_tiles) |
| sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) |
| |
| |
| def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Schedule: |
| sch = tir.Schedule(mod=mod, debug_mask="all") |
| sch_fn(sch) |
| return sch |
| |
| |
| def _create_tmp_database(tmpdir: str, mod_eq: str = "structural") -> ms.database.JSONDatabase: |
| path_workload = osp.join(tmpdir, "workloads.json") |
| path_tuning_record = osp.join(tmpdir, "tuning_records.json") |
| return ms.database.JSONDatabase(path_workload, path_tuning_record, module_equality=mod_eq) |
| |
| |
| def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord): |
| assert str(a.trace) == str(b.trace) |
| assert str(a.run_secs) == str(b.run_secs) |
| # AWAIT(@zxybazh): change to export after fixing "(bool)0" |
| assert str(a.target) == str(b.target) |
| tvm.ir.assert_structural_equal(a.workload.mod, b.workload.mod) |
| for arg0, arg1 in zip(a.args_info, b.args_info): |
| assert str(arg0.as_json()) == str(arg1.as_json()) |
| |
| |
| @ms.utils.derived_object |
| class PyMemoryDatabaseDefault(ms.database.PyDatabase): |
| def __init__(self): |
| super().__init__() |
| self.tuning_records_: List[TuningRecord] = [] |
| self.workloads_: List[Workload] = [] |
| |
| def has_workload(self, mod: IRModule) -> bool: |
| for workload in self.workloads_: |
| if tvm.ir.structural_equal(mod, workload.mod): |
| return True |
| |
| def commit_workload(self, mod: IRModule) -> ms.database.Workload: |
| if self.has_workload(mod): |
| for workload in self.workloads_: |
| if tvm.ir.structural_equal(mod, workload.mod): |
| return workload |
| else: |
| workload = ms.database.Workload(mod) |
| self.workloads_.append(workload) |
| return workload |
| |
| def commit_tuning_record(self, record: TuningRecord) -> None: |
| self.tuning_records_.append(record) |
| |
| def get_all_tuning_records(self) -> List[TuningRecord]: |
| return self.tuning_records_ |
| |
| def get_top_k(self, workload: ms.database.Workload, top_k: int) -> List[TuningRecord]: |
| return sorted( |
| list( |
| filter( |
| lambda x: tvm.ir.structural_equal(workload.mod, x.workload.mod), |
| self.tuning_records_, |
| ) |
| ), |
| key=lambda x: sum(x.run_secs) / len(x.run_secs) if x.run_secs else 1e9, |
| )[:top_k] |
| |
| def __len__(self) -> int: |
| return len(self.tuning_records_) |
| |
| |
| @ms.utils.derived_object |
| class PyMemoryDatabaseOverride(ms.database.PyDatabase): |
| def __init__(self): |
| super().__init__() |
| self.tuning_records_: List[TuningRecord] = [] |
| self.workloads_: List[Workload] = [] |
| |
| def has_workload(self, mod: IRModule) -> bool: |
| for workload in self.workloads_: |
| if tvm.ir.structural_equal(mod, workload.mod): |
| return True |
| |
| def commit_workload(self, mod: IRModule) -> ms.database.Workload: |
| if self.has_workload(mod): |
| for workload in self.workloads_: |
| if tvm.ir.structural_equal(mod, workload.mod): |
| return workload |
| else: |
| workload = ms.database.Workload(mod) |
| self.workloads_.append(workload) |
| return workload |
| |
| def commit_tuning_record(self, record: TuningRecord) -> None: |
| self.tuning_records_.append(record) |
| |
| def get_all_tuning_records(self) -> List[TuningRecord]: |
| return self.tuning_records_ |
| |
| def get_top_k(self, workload: ms.database.Workload, top_k: int) -> List[TuningRecord]: |
| return sorted( |
| list( |
| filter( |
| lambda x: tvm.ir.structural_equal(workload.mod, x.workload.mod), |
| self.tuning_records_, |
| ) |
| ), |
| key=lambda x: sum(x.run_secs) / len(x.run_secs) if x.run_secs else 1e9, |
| )[:top_k] |
| |
| def __len__(self) -> int: |
| return len(self.tuning_records_) |
| |
| def query_tuning_record( |
| self, mod: IRModule, target: Target, workload_name: Optional[str] = None |
| ) -> Optional[TuningRecord]: |
| if self.has_workload(mod): |
| records = self.get_top_k(self.commit_workload(mod), 2) |
| if len(records) == 1: |
| return records[0] |
| elif len(records) == 2: |
| return records[1] # return the 2nd best if there are two records |
| return None |
| |
| def query_schedule( |
| self, mod: IRModule, target: Target, workload_name: Optional[str] = None |
| ) -> Optional[Schedule]: |
| record = self.query_tuning_record(mod, target, workload_name) |
| if record is not None: |
| sch = Schedule(record.workload.mod) |
| record.trace.apply_to_schedule(sch, remove_postproc=False) |
| return sch |
| return None |
| |
| def query_ir_module( |
| self, mod: IRModule, target: Target, workload_name: Optional[str] = None |
| ) -> Optional[IRModule]: |
| record = self.query_tuning_record(mod, target, workload_name) |
| if record is not None: |
| sch = Schedule(record.workload.mod) |
| record.trace.apply_to_schedule(sch, remove_postproc=False) |
| return sch.mod |
| return None |
| |
| |
| def test_meta_schedule_tuning_record_round_trip(): |
| mod: IRModule = Matmul |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| workload = database.commit_workload(mod) |
| record = ms.database.TuningRecord( |
| _create_schedule(mod, _schedule_matmul).trace, |
| workload, |
| [T.float32(1.5), T.float32(2.5), T.float32(1.8)], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ) |
| database.commit_tuning_record(record) |
| new_record = ms.database.TuningRecord.from_json(record.as_json(), workload) |
| _equal_record(record, new_record) |
| |
| |
| def test_meta_schedule_database_create(): |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| assert osp.exists(database.path_workload) |
| assert osp.exists(database.path_tuning_record) |
| |
| |
| def test_meta_schedule_database_has_workload(): |
| mod: IRModule = Matmul |
| missing_mod: IRModule = MatmulRelu |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| workload = database.commit_workload(mod) |
| record = ms.database.TuningRecord( |
| _create_schedule(mod, _schedule_matmul).trace, |
| workload, |
| [1.5, 2.5, 1.8], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ) |
| database.commit_tuning_record(record) |
| assert len(database) == 1 |
| assert database.has_workload(mod) |
| assert not database.has_workload(missing_mod) |
| |
| |
| def test_meta_schedule_database_add_entry(): |
| mod: IRModule = Matmul |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| workload = database.commit_workload(mod) |
| record = ms.database.TuningRecord( |
| _create_schedule(mod, _schedule_matmul).trace, |
| workload, |
| [1.5, 2.5, 1.8], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ) |
| database.commit_tuning_record(record) |
| assert len(database) == 1 |
| (ret,) = database.get_top_k(workload, 3) |
| _equal_record(ret, record) |
| |
| |
| def test_meta_schedule_database_missing(): |
| mod: IRModule = Matmul |
| mod_2: IRModule = MatmulRelu |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| workload = database.commit_workload(mod) |
| workload_2 = database.commit_workload(mod_2) |
| record = ms.database.TuningRecord( |
| _create_schedule(mod, _schedule_matmul).trace, |
| workload, |
| [1.5, 2.5, 1.8], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ) |
| database.commit_tuning_record(record) |
| ret = database.get_top_k(workload_2, 3) |
| assert len(ret) == 0 |
| |
| |
| def test_meta_schedule_database_sorting(): |
| mod: IRModule = Matmul |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| token = database.commit_workload(mod) |
| trace = _create_schedule(mod, _schedule_matmul).trace |
| records = [ |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [7.0, 8.0, 9.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [1.0, 2.0, 3.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [4.0, 5.0, 6.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [1.1, 1.2, 600.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [1.0, 100.0, 6.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [4.0, 9.0, 8.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ] |
| for record in records: |
| database.commit_tuning_record(record) |
| ret = database.get_top_k(token, 2) |
| assert len(ret) == 2 |
| try: |
| _equal_record(ret[0], records[2]) |
| _equal_record(ret[1], records[1]) |
| except AssertionError: |
| _equal_record(ret[0], records[1]) |
| _equal_record(ret[1], records[2]) |
| |
| |
| def test_meta_schedule_database_reload(): |
| mod: IRModule = Matmul |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| token = database.commit_workload(mod) |
| trace = _create_schedule(mod, _schedule_matmul).trace |
| records = [ |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [7.0, 8.0, 9.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [1.0, 2.0, 3.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ms.database.TuningRecord( |
| trace, |
| token, |
| [4.0, 5.0, 6.0], |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ), |
| ] |
| for record in records: |
| database.commit_tuning_record(record) |
| new_database = ms.database.JSONDatabase( |
| path_workload=database.path_workload, |
| path_tuning_record=database.path_tuning_record, |
| ) |
| token = new_database.commit_workload(mod) |
| ret = new_database.get_top_k(token, 2) |
| assert len(ret) == 2 |
| try: |
| _equal_record(ret[0], records[2]) |
| _equal_record(ret[1], records[1]) |
| except AssertionError: |
| _equal_record(ret[0], records[1]) |
| _equal_record(ret[1], records[2]) |
| |
| |
| def test_meta_schedule_database_union(): |
| mod: IRModule = Matmul |
| target = tvm.target.Target("llvm") |
| arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]) |
| db_1 = ms.database.MemoryDatabase() |
| db_2 = ms.database.MemoryDatabase() |
| trace = _create_schedule(mod, _schedule_matmul).trace |
| |
| def query(db): # pylint: disable=invalid-name |
| return db.query_tuning_record(mod=mod, target=target, workload_name="main").run_secs |
| |
| def commit_record(db, run_sec): # pylint: disable=invalid-name |
| db.commit_tuning_record( |
| ms.database.TuningRecord( |
| trace, |
| workload=db.commit_workload(mod), |
| run_secs=[run_sec], |
| target=target, |
| args_info=arg_info, |
| ) |
| ) |
| |
| commit_record(db_1, 1.0) |
| (run_sec,) = query(db_1) |
| assert run_sec.value == 1.0 |
| |
| commit_record(db_2, 0.5) |
| (run_sec,) = query(db_2) |
| assert run_sec.value == 0.5 |
| |
| (run_secs,) = query(ms.database.UnionDatabase(db_1, db_2)) |
| assert run_secs.value == 0.5 |
| |
| (run_secs,) = query(ms.database.OrderedUnionDatabase(db_1, db_2)) |
| assert run_secs.value == 1.0 |
| |
| |
| def test_meta_schedule_pydatabase_default_query(): |
| mod: IRModule = Matmul |
| target = tvm.target.Target("llvm") |
| arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]) |
| db = PyMemoryDatabaseDefault() # pylint: disable=invalid-name |
| sch = _create_schedule(mod, _schedule_matmul) |
| trace = sch.trace |
| |
| def query(db, mod, target, kind): # pylint: disable=invalid-name |
| return db.query(mod=mod, target=target, workload_name="main", kind=kind) |
| |
| def commit_record(trace, db, run_sec): # pylint: disable=invalid-name |
| db.commit_tuning_record( |
| ms.database.TuningRecord( |
| trace, |
| workload=db.commit_workload(mod), |
| run_secs=[run_sec], |
| target=target, |
| args_info=arg_info, |
| ) |
| ) |
| |
| commit_record(trace, db, 1.0) |
| record = query(db, mod, target, "record") |
| assert record is not None and record.run_secs[0].value == 1.0 |
| sch_res = query(db, mod, target, "schedule") |
| assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, sch.mod) |
| mod_res = query(db, mod, target, "ir_module") |
| assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod) |
| |
| commit_record(Schedule(mod).trace, db, 0.2) # Empty Trace |
| record = query(db, mod, target, "record") |
| assert record is not None and record.run_secs[0].value == 0.2 |
| sch_res = query(db, mod, target, "schedule") |
| assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, mod) |
| mod_res = query(db, mod, target, "ir_module") |
| assert mod_res is not None and tvm.ir.structural_equal(mod_res, mod) |
| |
| |
| def test_meta_schedule_pydatabase_override_query(): |
| mod: IRModule = Matmul |
| target = tvm.target.Target("llvm") |
| arg_info = ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]) |
| db = PyMemoryDatabaseOverride() # pylint: disable=invalid-name |
| sch = _create_schedule(mod, _schedule_matmul) |
| trace = sch.trace |
| |
| def query(db, mod, target, kind): # pylint: disable=invalid-name |
| return db.query(mod=mod, target=target, workload_name="main", kind=kind) |
| |
| def commit_record(trace, db, run_sec): # pylint: disable=invalid-name |
| db.commit_tuning_record( |
| ms.database.TuningRecord( |
| trace, |
| workload=db.commit_workload(mod), |
| run_secs=[run_sec], |
| target=target, |
| args_info=arg_info, |
| ) |
| ) |
| |
| commit_record(trace, db, 1.14) |
| record = query(db, mod, target, "record") |
| assert record is not None and record.run_secs[0].value == 1.14 |
| sch_res = query(db, mod, target, "schedule") |
| assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, sch.mod) |
| mod_res = query(db, mod, target, "ir_module") |
| assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod) |
| |
| commit_record(Schedule(mod).trace, db, 0.514) # Empty Trace |
| record = query(db, mod, target, "record") |
| assert record is not None and record.run_secs[0].value == 1.14 # Override to 2nd best |
| sch_res = query(db, mod, target, "schedule") |
| assert sch_res is not None and tvm.ir.structural_equal(sch_res.mod, sch.mod) |
| mod_res = query(db, mod, target, "ir_module") |
| assert mod_res is not None and tvm.ir.structural_equal(mod_res, sch.mod) |
| |
| |
| def test_meta_schedule_pydatabase_current(): |
| db = PyMemoryDatabaseDefault() # pylint: disable=invalid-name |
| with db: # pylint: disable=not-context-manager |
| assert ms.database.Database.current() == db |
| |
| |
| def call_get_top_k(run_secs_list, database, k): |
| mod: IRModule = Matmul |
| workload = database.commit_workload(mod) |
| for run_secs in run_secs_list: |
| record = ms.database.TuningRecord( |
| _create_schedule(mod, _schedule_matmul).trace, |
| workload, |
| run_secs, |
| tvm.target.Target("llvm"), |
| ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]), |
| ) |
| database.commit_tuning_record(record) |
| return [[v.value for v in record.run_secs] for record in database.get_top_k(workload, k)] |
| |
| |
| @pytest.mark.parametrize( |
| "k,expected", |
| [ |
| (0, []), |
| (1, [[0.0, 2.0]]), |
| (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), |
| (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), |
| ], |
| ) |
| def test_memory_database_get_top_k(k, expected): |
| run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]] |
| database = ms.database.MemoryDatabase() |
| result = call_get_top_k(run_secs_list, database, k) |
| assert result == expected |
| |
| |
| @pytest.mark.parametrize( |
| "k,expected", |
| [ |
| (0, []), |
| (4, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), |
| (5, [[0.0, 2.0], [2.0], [1.5, 4.5], [3.0, 1e10]]), |
| ], |
| ) |
| def test_json_database_get_top_k(k, expected): |
| run_secs_list = [[1.5, 4.5], [], [0.0, 2.0], None, [2.0], [3.0, 1e10], [1e10]] |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir) |
| result = call_get_top_k(run_secs_list, database, k) |
| assert result == expected |
| |
| |
| def MatmulPrimFunc() -> IRModule: |
| return Matmul |
| |
| |
| @pytest.mark.parametrize("f_mod", [MatmulPrimFunc]) |
| @pytest.mark.parametrize("mod_eq", ["structural", "ignore-tensor", "anchor-block"]) |
| def test_json_database_commit_workload(f_mod, mod_eq): |
| mod: IRModule = f_mod() |
| with tempfile.TemporaryDirectory() as tmpdir: |
| database = _create_tmp_database(tmpdir, mod_eq) |
| database.commit_workload(mod) |
| |
| |
| @pytest.mark.parametrize("f_mod", [MatmulPrimFunc]) |
| @pytest.mark.parametrize("mod_eq", ["structural", "ignore-tensor", "anchor-block"]) |
| def test_memory_database_commit_workload(f_mod, mod_eq): |
| mod: IRModule = f_mod() |
| database = ms.database.MemoryDatabase(module_equality=mod_eq) |
| database.commit_workload(mod) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |