blob: 40ee24e077b4501bb8b980041328d4ce487bccfa [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-enumerate,invalid-name
"""Namespace of callback utilities of AutoTVM"""
import sys
import time
import logging
import numpy as np
from .. import record
from ..utils import format_si_prefix
# Module-level logger for AutoTVM; progress_bar consults its level to decide
# whether to draw the stdout progress bar.
logger = logging.getLogger("autotvm")
def log_to_file(file_out, protocol="json"):
    """Log the tuning records into file.

    Every (input, result) pair is serialized with ``autotvm.record.encode`` and
    written as one newline-terminated row.

    Parameters
    ----------
    file_out : File or str
        The file to log to. A ``str`` or ``pathlib.Path`` is opened in append
        mode on every callback invocation; any other object is written to
        directly through its ``write`` method.
    protocol: str, optional
        The log protocol. Can be 'json' or 'pickle'

    Returns
    -------
    callback : callable
        Callback function to do the logging.
    """
    # pylint: disable=import-outside-toplevel
    from pathlib import Path

    # A Path is treated like a filename string, so the callback reopens it each time.
    if isinstance(file_out, Path):
        file_out = str(file_out)

    def _emit(stream, inputs, results):
        # One encoded row per measured (input, result) pair.
        for measure_input, measure_result in zip(inputs, results):
            stream.write(record.encode(measure_input, measure_result, protocol) + "\n")

    def _callback(_, inputs, results):
        """Callback implementation"""
        if not isinstance(file_out, str):
            _emit(file_out, inputs, results)
            return
        with open(file_out, "a") as handle:
            _emit(handle, inputs, results)

    return _callback
def log_to_database(db):
    """Save the tuning records to a database object.

    Parameters
    ----------
    db: Database
        The database. Each (input, result) pair is handed to ``db.save``.

    Returns
    -------
    callback : callable
        Callback function that saves every measured pair it receives.
    """

    def _callback(_, inputs, results):
        """Callback implementation"""
        for measured_pair in zip(inputs, results):
            db.save(*measured_pair)

    return _callback
class Monitor:
    """A monitor to collect statistic during tuning

    Used as a tuning callback: for every measured result it records a score
    (``task.flop / mean(costs)`` for successful runs, ``0`` otherwise) and the
    result's timestamp.
    """

    def __init__(self):
        self.scores = []
        self.timestamps = []

    def __call__(self, tuner, inputs, results):
        """Record one score and one timestamp per (input, result) pair."""
        for measure_input, measure_result in zip(inputs, results):
            # A non-zero error_no marks a failed measurement, which scores 0.
            score = (
                measure_input.task.flop / np.mean(measure_result.costs)
                if measure_result.error_no == 0
                else 0
            )
            self.scores.append(score)
            self.timestamps.append(measure_result.timestamp)

    def reset(self):
        """Discard all collected scores and timestamps."""
        self.scores = []
        self.timestamps = []

    def trial_scores(self):
        """get scores (currently is flops) of all trials"""
        return np.array(self.scores)

    def trial_timestamps(self):
        """get wall clock time stamp of all trials"""
        return np.array(self.timestamps)
def progress_bar(total, prefix="", si_prefix="G"):
    """Display progress bar for tuning

    Parameters
    ----------
    total: int
        The total number of trials
    prefix: str
        The prefix of output message
    si_prefix: str
        SI prefix for flops

    Returns
    -------
    callback : callable
        Callback function that redraws the progress line on stdout after each
        batch of measurements. It reads ``tuner.best_flops`` from the tuner it
        is called with.
    """

    def _bar_enabled():
        # The bar is only shown when DEBUG logging is off. Use the logger's
        # *effective* level everywhere so the initial line, the per-batch
        # updates and the trailing " Done." always agree. Previously the first
        # and last checked ``logger.level`` while updates checked isEnabledFor.
        return not logger.isEnabledFor(logging.DEBUG)

    class _Context(object):
        """Context to store local variables"""

        def __init__(self):
            self.best_flops = 0
            self.cur_flops = 0
            self.ct = 0
            self.total = total

        def __del__(self):
            # Terminates the progress line when this context is garbage
            # collected (i.e. once the returned callback is released).
            if _bar_enabled():
                sys.stdout.write(" Done.\n")

    ctx = _Context()
    tic = time.time()

    # Validate si_prefix argument
    format_si_prefix(0, si_prefix)

    def _draw(cur_flops, best_flops):
        # "\r" rewinds to the start of the line so each draw overwrites the last.
        sys.stdout.write(
            "\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) "
            "| %.2f s"
            % (
                prefix,
                cur_flops,
                best_flops,
                si_prefix,
                ctx.ct,
                ctx.total,
                time.time() - tic,
            )
        )
        sys.stdout.flush()

    if _bar_enabled():
        _draw(0, 0)

    def _callback(tuner, inputs, results):
        ctx.ct += len(inputs)

        # "Current" is inp.task.flop / mean(res.costs) of the last successful
        # measurement in this batch, or 0 if every measurement failed.
        flops = 0
        for inp, res in zip(inputs, results):
            if res.error_no == 0:
                flops = inp.task.flop / np.mean(res.costs)

        if _bar_enabled():
            ctx.cur_flops = flops
            ctx.best_flops = tuner.best_flops
            _draw(
                format_si_prefix(ctx.cur_flops, si_prefix),
                format_si_prefix(ctx.best_flops, si_prefix),
            )

    return _callback