blob: dd78e73b9b84a66570b4b4f6aae85ba732ce3d2b [file] [log] [blame]
# pylint: skip-file
import numpy as np
import mxnet as mx
import time
import logging
from collections import namedtuple
from mxnet import optimizer as opt
from mxnet.optimizer import get_updater
from mxnet import metric
# Progress record handed to ``batch_end_callback`` after each training batch:
# the epoch index, the batch counter within that epoch (it is incremented
# before the callback fires, so the first batch reports 1), and the running
# evaluation metric object.
BatchEndParam = namedtuple(typename='BatchEndParams',
                           field_names=['epoch', 'nbatch', 'eval_metric'])
class Solver(object):
    """Minimal training loop that re-binds ``symbol`` for every batch.

    Every batch is copied into freshly created input arrays and the symbol is
    bound again, so the input shape is not frozen at bind time.
    NOTE(review): presumably this is meant to allow differently sized images
    per batch -- confirm against the data iterator in use.
    """

    def __init__(self, symbol, ctx=None,
                 begin_epoch=0, num_epoch=None,
                 arg_params=None, aux_params=None,
                 optimizer='sgd', **kwargs):
        """Store the training configuration; no computation happens here.

        Parameters
        ----------
        symbol : mx.symbol.Symbol
            Network to train.
        ctx : mx.Context, optional
            Device used for binding; defaults to ``mx.cpu()``.
        begin_epoch : int
            First epoch index (inclusive).
        num_epoch : int
            End epoch index (exclusive); required by :meth:`fit`.
        arg_params : dict of str to NDArray
            Initial network weights. :meth:`fit` writes the current batch's
            inputs into it and updates the weights in place, so it must be a
            dict (not ``None``) by the time ``fit`` runs.
        aux_params : dict of str to NDArray, optional
            Initial auxiliary states; missing ones are zero-filled by ``fit``.
        optimizer : str or mx.optimizer.Optimizer
            Optimizer name for ``mx.optimizer.create`` or a ready instance.
        **kwargs
            Extra keyword arguments forwarded to ``mx.optimizer.create``.
        """
        self.symbol = symbol
        self.ctx = mx.cpu() if ctx is None else ctx
        self.begin_epoch = begin_epoch
        self.num_epoch = num_epoch
        self.arg_params = arg_params
        self.aux_params = aux_params
        self.optimizer = optimizer
        self.kwargs = kwargs.copy()

    def fit(self, train_data, eval_data=None,
            eval_metric='acc',
            grad_req='write',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            logger=None):
        """Train for epochs ``[begin_epoch, num_epoch)``, validating after each.

        Parameters
        ----------
        train_data : data iterator
            Must expose ``provide_data``, ``data_name``, ``label_name``,
            ``get_batch_size()`` and ``reset()``, and yield dicts mapping
            ``data_name``/``label_name`` to arrays. Labels must be 3-D so they
            can be flattened to ``(N, d1 * d2)``.
        eval_data : data iterator, optional
            Same protocol as ``train_data``; evaluated after every epoch.
        eval_metric : str or mx.metric.EvalMetric
            Passed to ``mx.metric.create``.
        grad_req : str
            Gradient request used when binding for training; ``'null'``
            disables gradient buffers (and therefore weight updates).
        epoch_end_callback : callable, optional
            Called as ``cb(epoch, symbol, arg_params, aux_params)``.
        batch_end_callback : callable or list of callables, optional
            Each is called with a ``BatchEndParam`` after every training batch.
        kvstore : str
            Accepted for interface compatibility; not used.
        logger : logging.Logger, optional
            Defaults to the ``logging`` module.

        Raises
        ------
        ValueError
            If ``num_epoch`` was not supplied to the constructor.
        """
        if self.num_epoch is None:
            raise ValueError('num_epoch must be set before calling fit()')
        if logger is None:
            logger = logging
        logger.info('Start training with %s', str(self.ctx))

        arg_shapes, _, aux_shapes = self.symbol.infer_shape(
            data=train_data.provide_data[0][1])
        arg_names = self.symbol.list_arguments()
        if grad_req != 'null':
            # Gradient buffers for the learnable arguments only; the network
            # inputs (names ending in 'data' / 'label') get none.
            self.grad_params = {
                name: mx.nd.zeros(shape, self.ctx)
                for name, shape in zip(arg_names, arg_shapes)
                if not name.endswith(('data', 'label'))}
        else:
            self.grad_params = None

        # Keep caller-supplied auxiliary states and zero-fill only missing ones
        # on the training device.
        aux_params = dict(self.aux_params or {})
        for name, shape in zip(self.symbol.list_auxiliary_states(), aux_shapes):
            if name not in aux_params:
                aux_params[name] = mx.nd.zeros(shape, self.ctx)
        self.aux_params = aux_params

        data_name = train_data.data_name
        label_name = train_data.label_name
        if not isinstance(self.optimizer, opt.Optimizer):
            # Normalise gradients by the batch size unless the caller chose an
            # explicit rescale factor.
            optimizer_params = dict(self.kwargs)
            optimizer_params.setdefault(
                'rescale_grad', 1.0 / train_data.get_batch_size())
            self.optimizer = opt.create(self.optimizer, **optimizer_params)
        self.updater = get_updater(self.optimizer)
        eval_metric = metric.create(eval_metric)

        if batch_end_callback is None:
            batch_callbacks = []
        elif isinstance(batch_end_callback, (list, tuple)):
            batch_callbacks = list(batch_end_callback)
        else:
            batch_callbacks = [batch_end_callback]

        # begin training
        for epoch in range(self.begin_epoch, self.num_epoch):
            nbatch = 0
            train_data.reset()
            eval_metric.reset()
            for data in train_data:
                nbatch += 1
                label = self._load_batch(data, data_name, label_name)
                self.exector = self.symbol.bind(self.ctx, self.arg_params,
                                                args_grad=self.grad_params,
                                                grad_req=grad_req,
                                                aux_states=self.aux_params)
                assert len(self.symbol.list_arguments()) == len(self.exector.grad_arrays)
                update_dict = {name: grad for name, grad in
                               zip(self.symbol.list_arguments(), self.exector.grad_arrays)
                               if grad is not None}
                self.exector.forward(is_train=True)
                # Snapshot every output on the CPU for metric computation.
                output_buff = {key: arr.copyto(mx.cpu()) for key, arr in
                               zip(self.symbol.list_outputs(), self.exector.outputs)}
                self.exector.backward()
                for key, grad in update_dict.items():
                    # "bigscore_weight" is deliberately never updated.
                    if key != "bigscore_weight":
                        self.updater(key, grad, self.arg_params[key])
                pred = self._flatten_prediction(output_buff["softmax_output"])
                eval_metric.update([label], [pred])
                self.exector.outputs[0].wait_to_read()
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric)
                for callback in batch_callbacks:
                    callback(batch_end_params)
            if epoch_end_callback is not None:
                epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params)
            name, value = eval_metric.get()
            logger.info(" --->Epoch[%d] Train-%s=%f", epoch, name, value)
            # evaluation
            if eval_data is not None:
                self._evaluate(eval_data, data_name, label_name, eval_metric, logger)

    def _evaluate(self, eval_data, data_name, label_name, eval_metric, logger):
        """Run inference over ``eval_data`` and log the validation metric."""
        logger.info(" in eval process...")
        nbatch = 0
        eval_data.reset()
        eval_metric.reset()
        for data in eval_data:
            nbatch += 1
            label = self._load_batch(data, data_name, label_name)
            # Inference only: bind without gradient buffers.
            executor = self.symbol.bind(self.ctx, self.arg_params,
                                        args_grad=None,
                                        grad_req='null',
                                        aux_states=self.aux_params)
            executor.forward(is_train=False)
            pred = self._flatten_prediction(executor.outputs[0].copyto(mx.cpu()))
            eval_metric.update([label], [pred])
            executor.outputs[0].wait_to_read()
        name, value = eval_metric.get()
        logger.info('batch[%d] Validation-%s=%f', nbatch, name, value)

    def _load_batch(self, data, data_name, label_name):
        """Copy one batch into ``self.arg_params``; return the flat CPU label.

        The label is flattened from 3-D ``(N, d1, d2)`` to ``(N, d1 * d2)``
        both for the network input and for the metric.
        """
        label_shape = data[label_name].shape
        flat_label = data[label_name].reshape(label_shape[0],
                                              label_shape[1] * label_shape[2])
        self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx)
        self.arg_params[label_name] = mx.nd.array(flat_label, self.ctx)
        return mx.nd.array(flat_label)

    @staticmethod
    def _flatten_prediction(output):
        """Reshape a 4-D CPU output ``(N, C, d1, d2)`` to ``(N, C, d1 * d2)``."""
        shape = output.shape
        return mx.nd.array(output.asnumpy().reshape(shape[0], shape[1],
                                                    shape[2] * shape[3]))