| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: skip-file |
import logging
from collections import namedtuple

import mxnet as mx
from mxnet import metric
from mxnet import optimizer as opt
from mxnet.optimizer import get_updater
| |
# Parameters passed to batch_end_callback
BatchEndParam = namedtuple('BatchEndParam', ['epoch', 'nbatch', 'eval_metric'])


| class Solver(object): |
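    """A simple batch-by-batch training helper for a symbolic network.

    The symbol is re-bound on ``ctx`` for every batch and parameters are
    updated with the named ``optimizer``; any extra keyword arguments are
    forwarded to the optimizer (e.g. ``learning_rate``, ``momentum``, ``wd``).
    ``num_epoch`` must be set before calling :meth:`fit`.
    """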
| def __init__(self, symbol, ctx=None, |
| begin_epoch=0, num_epoch=None, |
| arg_params=None, aux_params=None, |
| optimizer='sgd', **kwargs): |
| self.symbol = symbol |
| if ctx is None: |
| ctx = mx.cpu() |
| self.ctx = ctx |
| self.begin_epoch = begin_epoch |
| self.num_epoch = num_epoch |
| self.arg_params = arg_params |
| self.aux_params = aux_params |
| self.optimizer = optimizer |
| self.kwargs = kwargs.copy() |
| |
| def fit(self, train_data, eval_data=None, |
| eval_metric='acc', |
| grad_req='write', |
| epoch_end_callback=None, |
| batch_end_callback=None, |
| kvstore='local', |
| logger=None): |
| if logger is None: |
| logger = logging |
        logger.info('Start training with %s', str(self.ctx))
        arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=train_data.provide_data[0][1])
| arg_names = self.symbol.list_arguments() |
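        # Allocate gradient buffers for every learnable argument; the data and
        # label inputs do not receive gradients.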
| if grad_req != 'null': |
| self.grad_params = {} |
| for name, shape in zip(arg_names, arg_shapes): |
| if not (name.endswith('data') or name.endswith('label')): |
| self.grad_params[name] = mx.nd.zeros(shape, self.ctx) |
| else: |
| self.grad_params = None |
| aux_names = self.symbol.list_auxiliary_states() |
        self.aux_params = {k: mx.nd.zeros(s, self.ctx) for k, s in zip(aux_names, aux_shapes)}
        data_name = train_data.data_name
        label_name = train_data.label_name
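        # Rescale gradients by 1/batch_size so the update magnitude does not
        # depend on the batch size.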
        self.optimizer = opt.create(self.optimizer,
                                    rescale_grad=(1.0 / train_data.get_batch_size()),
                                    **self.kwargs)
| self.updater = get_updater(self.optimizer) |
| eval_metric = metric.create(eval_metric) |
| # begin training |
| for epoch in range(self.begin_epoch, self.num_epoch): |
| nbatch = 0 |
| train_data.reset() |
| eval_metric.reset() |
| for data in train_data: |
| nbatch += 1 |
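                # Feed the current batch: labels of shape (N, H, W) are
                # flattened to (N, H*W) to match the softmax output layout.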
| label_shape = data[label_name].shape |
| self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx) |
                self.arg_params[label_name] = mx.nd.array(
                    data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2]),
                    self.ctx)
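                # Bind a fresh executor for this batch; re-binding every
                # iteration tolerates varying input shapes at the cost of
                # repeated memory allocation.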
                self.executor = self.symbol.bind(self.ctx, self.arg_params,
                                                 args_grad=self.grad_params,
                                                 grad_req=grad_req,
                                                 aux_states=self.aux_params)
                assert len(self.symbol.list_arguments()) == len(self.executor.grad_arrays)
                update_dict = {name: nd for name, nd in
                               zip(self.symbol.list_arguments(), self.executor.grad_arrays)
                               if nd is not None}
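                # Mirror the executor outputs into CPU buffers; the metric is
                # computed from these copies rather than from device memory.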
| output_dict = {} |
| output_buff = {} |
                for key, arr in zip(self.symbol.list_outputs(), self.executor.outputs):
                    output_dict[key] = arr
                    output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
                self.executor.forward(is_train=True)
                for key in output_dict:
                    output_dict[key].copyto(output_buff[key])
                self.executor.backward()
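                # Apply the optimizer update to every argument that received a
                # gradient; 'bigscore_weight' is skipped and left unchanged.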
| for key, arr in update_dict.items(): |
| if key != "bigscore_weight": |
| self.updater(key, arr, self.arg_params[key]) |
                pred_shape = self.executor.outputs[0].shape
                label = mx.nd.array(data[label_name].reshape(label_shape[0],
                                                             label_shape[1]*label_shape[2]))
                # the network's first output is assumed to be named 'softmax_output'
                pred = mx.nd.array(output_buff["softmax_output"].asnumpy().reshape(
                    pred_shape[0], pred_shape[1], pred_shape[2]*pred_shape[3]))
                eval_metric.update([label], [pred])
                self.executor.outputs[0].wait_to_read()
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric)
                    batch_end_callback(batch_end_params)
| if epoch_end_callback is not None: |
| epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params) |
| name, value = eval_metric.get() |
| logger.info(" --->Epoch[%d] Train-%s=%f", epoch, name, value) |
| # evaluation |
| if eval_data: |
                logger.info('Running evaluation...')
| nbatch = 0 |
| eval_data.reset() |
| eval_metric.reset() |
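                # Validation pass: same feed-and-bind pattern as training, but
                # forward-only (is_train=False) and no parameter updates.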
| for data in eval_data: |
| nbatch += 1 |
| label_shape = data[label_name].shape |
| self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx) |
                    self.arg_params[label_name] = mx.nd.array(
                        data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2]),
                        self.ctx)
                    executor = self.symbol.bind(self.ctx, self.arg_params,
                                                args_grad=self.grad_params,
                                                grad_req=grad_req,
                                                aux_states=self.aux_params)
                    cpu_output_array = mx.nd.zeros(executor.outputs[0].shape)
                    executor.forward(is_train=False)
                    executor.outputs[0].copyto(cpu_output_array)
| pred_shape = cpu_output_array.shape |
                    label = mx.nd.array(data[label_name].reshape(
                        label_shape[0], label_shape[1]*label_shape[2]))
                    pred = mx.nd.array(cpu_output_array.asnumpy().reshape(
                        pred_shape[0], pred_shape[1], pred_shape[2]*pred_shape[3]))
| eval_metric.update([label], [pred]) |
                    executor.outputs[0].wait_to_read()
                name, value = eval_metric.get()
                logger.info('Epoch[%d] Validation-%s=%f', epoch, name, value)
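
# A minimal usage sketch (hypothetical names, not part of this module). The
# iterator passed to fit() is assumed to expose .provide_data, .data_name,
# .label_name, .get_batch_size() and .reset(), and to yield dict-like batches
# mapping those names to arrays, since that is the interface Solver relies on:
#
#     solver = Solver(symbol, ctx=mx.gpu(0), num_epoch=10,
#                     optimizer='sgd', learning_rate=0.01,
#                     momentum=0.9, wd=0.0005)
#     solver.fit(train_iter, eval_iter,
#                batch_end_callback=mx.callback.Speedometer(batch_size, 10))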