# pylint: skip-file
import numpy as np
import mxnet as mx
import time
import logging
from collections import namedtuple
from mxnet import optimizer as opt
from mxnet.optimizer import get_updater
from mxnet import metric
# Parameter to pass to batch_end_callback
BatchEndParam = namedtuple('BatchEndParam', ['epoch', 'nbatch', 'eval_metric'])
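# Hedged example (not part of the original file): a batch_end_callback
# receives one BatchEndParam per processed batch; a minimal callback might
# look like this:
#
#     def log_batch(param):
#         name, value = param.eval_metric.get()
#         logging.info('Epoch[%d] Batch[%d] Train-%s=%f',
#                      param.epoch, param.nbatch, name, value)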
class Solver(object):
    def __init__(self, symbol, ctx=None,
                 begin_epoch=0, num_epoch=None,
                 arg_params=None, aux_params=None,
                 optimizer='sgd', **kwargs):
        self.symbol = symbol
        if ctx is None:
            ctx = mx.cpu()
        self.ctx = ctx
        self.begin_epoch = begin_epoch
        self.num_epoch = num_epoch
        self.arg_params = arg_params
        self.aux_params = aux_params
        self.optimizer = optimizer
        self.kwargs = kwargs.copy()
    def fit(self, train_data, eval_data=None,
            eval_metric='acc',
            grad_req='write',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            logger=None):
        if logger is None:
            logger = logging
        logger.info('Start training with %s', str(self.ctx))
        arg_shapes, out_shapes, aux_shapes = self.symbol.infer_shape(data=train_data.provide_data[0][1])
        arg_names = self.symbol.list_arguments()
        # Allocate gradient buffers for every learnable argument;
        # the data and label inputs need no gradients.
        if grad_req != 'null':
            self.grad_params = {}
            for name, shape in zip(arg_names, arg_shapes):
                if not (name.endswith('data') or name.endswith('label')):
                    self.grad_params[name] = mx.nd.zeros(shape, self.ctx)
        else:
            self.grad_params = None
        aux_names = self.symbol.list_auxiliary_states()
        self.aux_params = {k: mx.nd.zeros(s, self.ctx) for k, s in zip(aux_names, aux_shapes)}
        data_name = train_data.data_name
        label_name = train_data.label_name
        input_names = [data_name, label_name]
        self.optimizer = opt.create(self.optimizer,
                                    rescale_grad=(1.0 / train_data.get_batch_size()),
                                    **(self.kwargs))
        self.updater = get_updater(self.optimizer)
        eval_metric = metric.create(eval_metric)
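        # Hedged note (an assumption about get_updater, not stated in this
        # file): updater(key, grad, weight) applies one optimizer step in
        # place; for plain SGD that is roughly
        #
        #     weight[:] -= lr * (rescale_grad * grad + wd * weight)
        #
        # which is why rescale_grad is set to 1/batch_size above.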
        # begin training
        for epoch in range(self.begin_epoch, self.num_epoch):
            nbatch = 0
            train_data.reset()
            eval_metric.reset()
            for data in train_data:
                nbatch += 1
                label_shape = data[label_name].shape
                self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx)
                self.arg_params[label_name] = mx.nd.array(
                    data[label_name].reshape(label_shape[0], label_shape[1] * label_shape[2]), self.ctx)
                output_names = self.symbol.list_outputs()
                # Re-bind the executor each batch, since input shapes may
                # differ from batch to batch.
                self.executor = self.symbol.bind(self.ctx, self.arg_params,
                                                 args_grad=self.grad_params,
                                                 grad_req=grad_req,
                                                 aux_states=self.aux_params)
                assert len(self.symbol.list_arguments()) == len(self.executor.grad_arrays)
                update_dict = {name: nd for name, nd in
                               zip(self.symbol.list_arguments(), self.executor.grad_arrays)
                               if nd is not None}
                output_dict = {}
                output_buff = {}
                for key, arr in zip(self.symbol.list_outputs(), self.executor.outputs):
                    output_dict[key] = arr
                    output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
                self.executor.forward(is_train=True)
                for key in output_dict:
                    output_dict[key].copyto(output_buff[key])
                self.executor.backward()
                for key, arr in update_dict.items():
                    if key != "bigscore_weight":
                        self.updater(key, arr, self.arg_params[key])
                pred_shape = self.executor.outputs[0].shape
                label = mx.nd.array(data[label_name].reshape(label_shape[0],
                                                             label_shape[1] * label_shape[2]))
                pred = mx.nd.array(output_buff["softmax_output"].asnumpy().reshape(
                    pred_shape[0], pred_shape[1], pred_shape[2] * pred_shape[3]))
                eval_metric.update([label], [pred])
                self.executor.outputs[0].wait_to_read()
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=eval_metric)
                if batch_end_callback is not None:
                    batch_end_callback(batch_end_params)
            if epoch_end_callback is not None:
                epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params)
            name, value = eval_metric.get()
            logger.info(" --->Epoch[%d] Train-%s=%f", epoch, name, value)
            # evaluation
            if eval_data:
                logger.info(" in eval process...")
                nbatch = 0
                eval_data.reset()
                eval_metric.reset()
                for data in eval_data:
                    nbatch += 1
                    label_shape = data[label_name].shape
                    self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx)
                    self.arg_params[label_name] = mx.nd.array(
                        data[label_name].reshape(label_shape[0], label_shape[1] * label_shape[2]), self.ctx)
                    executor = self.symbol.bind(self.ctx, self.arg_params,
                                                args_grad=self.grad_params,
                                                grad_req=grad_req,
                                                aux_states=self.aux_params)
                    cpu_output_array = mx.nd.zeros(executor.outputs[0].shape)
                    # Forward pass only; no gradients are needed for evaluation.
                    executor.forward(is_train=False)
                    executor.outputs[0].copyto(cpu_output_array)
                    pred_shape = cpu_output_array.shape
                    label = mx.nd.array(data[label_name].reshape(label_shape[0],
                                                                 label_shape[1] * label_shape[2]))
                    pred = mx.nd.array(cpu_output_array.asnumpy().reshape(
                        pred_shape[0], pred_shape[1], pred_shape[2] * pred_shape[3]))
                    eval_metric.update([label], [pred])
                    executor.outputs[0].wait_to_read()
                name, value = eval_metric.get()
                logger.info('batch[%d] Validation-%s=%f', nbatch, name, value)
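
# Hedged usage sketch (not part of the original file). `symbol`,
# `train_dataiter`, and `val_dataiter` are hypothetical placeholders; the
# iterators are assumed to expose data_name, label_name, provide_data,
# get_batch_size(), reset(), and dict-style batch access, as fit() requires.
#
#     solver = Solver(symbol, ctx=mx.gpu(0),
#                     begin_epoch=0, num_epoch=50,
#                     arg_params=pretrained_args, aux_params=pretrained_auxs,
#                     optimizer='sgd', learning_rate=1e-4,
#                     momentum=0.9, wd=0.0005)
#     solver.fit(train_dataiter, eval_data=val_dataiter,
#                batch_end_callback=mx.callback.Speedometer(batch_size=1, frequent=10),
#                epoch_end_callback=mx.callback.do_checkpoint('model'))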