blob: c7ae19678a307a047d6326ced94b78c7b260197e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Serialization and other I/O support for measurement records (tuning logs). """
import numpy as np
import tvm._ffi
from tvm.runtime import Object
from .measure import MeasureCallback, MeasureErrorNo
from . import _ffi_api
@tvm._ffi.register_object("auto_scheduler.RecordToFile")
class RecordToFile(MeasureCallback):
    """
    Measurement callback that persists every measurement record to a log file.

    Parameters
    ----------
    filename : str = "auto_scheduler_tuning.json"
        Path of the log file this callback writes records to.
    """

    def __init__(self, filename="auto_scheduler_tuning.json"):
        # Delegate construction of the underlying object to the FFI constructor.
        ffi_ctor = _ffi_api.RecordToFile
        self.__init_handle_by_constructor__(ffi_ctor, filename)
@tvm._ffi.register_object("auto_scheduler.RecordReader")
class RecordReader(Object):
    """
    Reader for a JSON measurement-record log file.

    Parameters
    ----------
    filename : str = "auto_scheduler_tuning.json"
        Path of the log file this reader loads records from.
    """

    def __init__(self, filename="auto_scheduler_tuning.json"):
        ffi_ctor = _ffi_api.RecordReader
        self.__init_handle_by_constructor__(ffi_ctor, filename)

    def read_lines(self, max_lines=None, skip_lines=0):
        """Read several records from the log file in one call.

        Parameters
        ----------
        max_lines : Optional[int]
            Upper bound on the number of records to read. Any falsy value
            (``None`` or ``0``) is passed to the backend as ``-1``, which this
            wrapper uses to mean "read all lines".
        skip_lines : int = 0
            Number of leading records to skip.

        Returns
        -------
        inputs : List[auto_scheduler.measure.MeasureInput]
            The MeasureInputs loaded from the log file.
        results : List[auto_scheduler.measure.MeasureResult]
            The MeasureResults loaded from the log file.
        """
        limit = max_lines if max_lines else -1
        loaded = _ffi_api.RecordReaderReadLines(self, limit, skip_lines)
        # Unpack so callers always get a plain 2-tuple back.
        inputs, results = loaded
        return inputs, results

    def __iter__(self):
        """Yield ``(input, result)`` pairs until the backend reports no more records."""
        record = _ffi_api.RecordReaderReadNext(self)
        while record:
            yield record[0], record[1]
            record = _ffi_api.RecordReaderReadNext(self)
def load_records(filename):
    """
    Load measurement records from a file.

    Parameters
    ----------
    filename : str
        File name to load log from.

    Returns
    -------
    logs : Iterator[Tuple[auto_scheduler.measure.MeasureInput, auto_scheduler.measure.MeasureResult]]
        A lazy, single-pass iterator (a ``zip`` object) of ``(input, result)``
        pairs. Wrap it in ``list(...)`` if it must be indexed or traversed
        more than once.
    """
    # read_lines() with default arguments reads every record in the file.
    inputs, results = RecordReader(filename).read_lines()
    return zip(inputs, results)
def save_records(filename, inputs, results):
    """
    Append measurement records to a log file.

    Parameters
    ----------
    filename : str
        File name to write log to.
    inputs : List[auto_scheduler.measure.MeasureInput]
        The MeasureInputs to be written.
    results : List[auto_scheduler.measure.MeasureResult]
        The MeasureResults to be written (presumably paired element-wise with
        ``inputs`` — confirm against the FFI implementation).
    """
    writer = _ffi_api.SaveRecords
    writer(filename, inputs, results)
def load_best(filename, workload_key=None, target=None):
    """Return the best measurement pair from a log file.

    Only records measured without error are considered. Among those, the
    record with the lowest mean cost is selected. If no valid record matches
    the given ``workload_key`` / ``target``, ``(None, None)`` is returned.

    Parameters
    ----------
    filename : str
        File name to load log from.
    workload_key : Optional[str]
        The workload key of the compute declaration.
        With `None`, this returns the best measure pair of all workloads.
    target : Optional[tvm.target.Target]
        The target device. Only the target's kind name is compared.
        With `None`, this returns the best measure pair of all target devices.

    Returns
    -------
    input : Optional[auto_scheduler.measure.MeasureInput]
        The best State's MeasureInput from this log file, or None.
    result : Optional[auto_scheduler.measure.MeasureResult]
        The best State's MeasureResult from this log file, or None.
    """
    # Infinity guarantees the first valid record always wins, whatever its cost.
    best_cost = float("inf")
    best_inp = None
    best_res = None

    for inp, res in RecordReader(filename):
        # Only successfully measured records are candidates.
        if res.error_no != MeasureErrorNo.NO_ERROR:
            continue
        # `is not None` matches the documented contract: only None disables a filter.
        if workload_key is not None and inp.task.workload_key != workload_key:
            continue
        # Targets are matched by kind name only, not by the full target description.
        if target is not None and inp.task.target.kind.name != target.kind.name:
            continue

        costs = [v.value for v in res.costs]
        if not costs:
            # np.mean([]) warns and returns NaN, which could never be selected anyway.
            continue
        cost = np.mean(costs)
        if cost < best_cost:
            best_cost = cost
            best_inp = inp
            best_res = res

    return best_inp, best_res