| # coding: utf-8 |
| """Callback functions that can be used to track various status during epoch.""" |
| from __future__ import absolute_import |
| |
| import logging |
| import math |
| import sys |
| import time |
| from .model import save_checkpoint |
| |
def module_checkpoint(mod, prefix, period=1, save_optimizer_states=False):
    """Build a callback that checkpoints a Module to ``prefix`` periodically.

    Parameters
    ----------
    mod : subclass of BaseModule
        The module whose ``save_checkpoint`` method is invoked.
    prefix : str
        The file prefix passed through to ``mod.save_checkpoint``.
    period : int
        How many epochs to wait between checkpoints. Values below 1 are
        treated as 1. Defaults to 1.
    save_optimizer_states : bool
        Forwarded to ``mod.save_checkpoint``; indicates whether or not to
        save optimizer states for continued training.

    Returns
    -------
    callback : function
        The callback function that can be passed as iter_end_callback to fit.
    """
    interval = int(max(1, period))

    # pylint: disable=unused-argument
    def _callback(iter_no, sym=None, arg=None, aux=None):
        """Checkpoint the module once every ``interval`` finished epochs."""
        # ``iter_no`` is zero-based; checkpoints are numbered from 1.
        finished = iter_no + 1
        if finished % interval != 0:
            return
        mod.save_checkpoint(prefix, finished, save_optimizer_states)

    return _callback
| |
| |
def do_checkpoint(prefix, period=1):
    """Build a callback that checkpoints the model to ``prefix`` periodically.

    Parameters
    ----------
    prefix : str
        The file prefix passed through to ``save_checkpoint``.
    period : int
        How many epochs to wait between checkpoints. Values below 1 are
        treated as 1. Defaults to 1.

    Returns
    -------
    callback : function
        The callback function that can be passed as ``iter_end_callback`` to fit.
    """
    interval = int(max(1, period))

    def _callback(iter_no, sym, arg, aux):
        """Save symbol and parameters once every ``interval`` finished epochs."""
        # ``iter_no`` is zero-based; checkpoints are numbered from 1.
        finished = iter_no + 1
        if finished % interval != 0:
            return
        save_checkpoint(prefix, finished, sym, arg, aux)

    return _callback
| |
| |
def log_train_metric(period, auto_reset=False):
    """Build a callback that logs the training metric every ``period`` batches.

    Parameters
    ----------
    period : int
        The number of batches between two log outputs; logging happens when
        ``param.nbatch`` is a multiple of this value.
    auto_reset : bool
        Reset the metric after each log.

    Returns
    -------
    callback : function
        The callback function that can be passed as iter_epoch_callback to fit.
    """
    def _callback(param):
        """Log every (name, value) pair of the current training metric."""
        if param.nbatch % period != 0 or param.eval_metric is None:
            return
        metric = param.eval_metric
        for name, value in metric.get_name_value():
            logging.info('Iter[%d] Batch[%d] Train-%s=%f',
                         param.epoch, param.nbatch, name, value)
        if auto_reset:
            metric.reset()

    return _callback
| |
| |
class Speedometer(object):
    """Calculate and log training speed periodically.

    Parameters
    ----------
    batch_size: int
        batch_size of data.
    frequent: int
        How many batches between calculations.
        Defaults to calculating & logging every 50 batches.
    """
    def __init__(self, batch_size, frequent=50):
        self.batch_size = batch_size
        self.frequent = frequent
        # False until the first call of an epoch has started the timer.
        self.init = False
        # Timestamp (``time.time()``) at which the current window started.
        self.tic = 0
        # Batch index seen on the previous call, used to detect a new epoch.
        self.last_count = 0

    def __call__(self, param):
        """Callback to show speed.

        Logs the throughput in samples/sec whenever ``param.nbatch`` is a
        multiple of ``frequent``. When ``param.eval_metric`` is not None its
        values are logged alongside the speed and the metric is reset.
        """
        count = param.nbatch
        # A batch index lower than the previous one means the counter
        # restarted, so the timer has to be started afresh.
        if self.last_count > count:
            self.init = False
        self.last_count = count

        if not self.init:
            self.init = True
            self.tic = time.time()
            return

        if count % self.frequent != 0:
            return

        elapsed = time.time() - self.tic
        try:
            speed = self.frequent * self.batch_size / elapsed
        except ZeroDivisionError:
            # The clock did not advance between two log points (coarse timer
            # resolution or extremely fast batches); report an unbounded
            # speed instead of aborting training.
            speed = float('inf')

        if param.eval_metric is not None:
            name_value = param.eval_metric.get_name_value()
            param.eval_metric.reset()
            for name, value in name_value:
                logging.info('Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f',
                             param.epoch, count, speed, name, value)
        else:
            logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                         param.epoch, count, speed)
        self.tic = time.time()
| |
| |
class ProgressBar(object):
    """Show a progress bar on standard output.

    Parameters
    ----------
    total: int
        Value of ``param.nbatch`` that corresponds to a completely filled bar.
    length: int
        Length of the progress bar in characters. Defaults to 80.
    """
    def __init__(self, total, length=80):
        self.bar_len = length
        self.total = total

    def __call__(self, param):
        """Callback to show the progress bar for the current ``param.nbatch``."""
        count = param.nbatch
        filled_len = int(round(self.bar_len * count / float(self.total)))
        # int() keeps the rendering identical on Python 2, where math.ceil
        # returns a float and would otherwise print e.g. "50.0%".
        percents = int(math.ceil(100.0 * count / float(self.total)))
        prog_bar = '=' * filled_len + '-' * (self.bar_len - filled_len)
        sys.stdout.write('[%s] %s%s\r' % (prog_bar, percents, '%'))
        # The line ends with '\r' rather than '\n', so a buffered stdout
        # would not display it until much later; flush to draw the bar now.
        sys.stdout.flush()
| |
| |
class LogValidationMetricsCallback(object):
    """Log every evaluation metric value at the end of an epoch."""

    def __call__(self, param):
        """Emit one log line per (name, value) pair of ``param.eval_metric``.

        Nothing is logged when ``param.eval_metric`` is falsy.
        """
        metric = param.eval_metric
        if metric:
            for name, value in metric.get_name_value():
                logging.info('Epoch[%d] Validation-%s=%f', param.epoch, name, value)