# blob: 4af8c7604b93cd6fa1ea1c5cd2344308e5212ffa [file] [log] [blame]
# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, too-many-lines
# pylint: disable=too-many-branches, too-many-statements
"""MXNet model module"""
from __future__ import absolute_import, print_function
import os
import time
import logging
import warnings
from collections import namedtuple
import numpy as np
from . import io
from . import nd
from . import symbol as sym
from . import optimizer as opt
from . import metric
from . import kvstore as kvs
from .context import Context, cpu
from .initializer import Uniform
from .optimizer import get_updater
from .executor_manager import DataParallelExecutorManager, _check_arguments, _load_data
from .io import DataDesc
from .base import mx_real_t
# Base class used by FeedForward: scikit-learn's ``BaseEstimator`` when the
# package can be imported, plain ``object`` otherwise.
# ``SKLEARN_INSTALLED`` is defined on both paths so that reading it never
# raises NameError.
BASE_ESTIMATOR = object
try:
    from sklearn.base import BaseEstimator
except ImportError:
    SKLEARN_INSTALLED = False
else:
    BASE_ESTIMATOR = BaseEstimator
    SKLEARN_INSTALLED = True
# Record handed to the batch-end / eval-end callbacks. Note that the
# generated type is named 'BatchEndParams' while the module-level alias
# is ``BatchEndParam``.
BatchEndParam = namedtuple(
    typename='BatchEndParams',
    field_names=['epoch', 'nbatch', 'eval_metric', 'locals'],
)
def _create_kvstore(kvstore, num_device, arg_params):
    """Create kvstore

    This function selects and creates a proper kvstore if given the kvstore type.

    Parameters
    ----------
    kvstore : KVStore or str or None
        The kvstore, the name of a kvstore type to create, or None.
    num_device : int
        The number of devices
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.

    Returns
    -------
    kv : KVStore or None
        The kvstore to use. None when ``kvstore`` is None, or when a
        non-'dist' type is requested by name for a single device.
    update_on_kvstore : bool
        Whether weight updates should be performed on the kvstore.
        Always False when ``kv`` is None.

    Raises
    ------
    TypeError
        If ``kvstore`` is not a KVStore, a str or None.
    """
    update_on_kvstore = True
    if kvstore is None:
        kv = None
    elif isinstance(kvstore, str):
        # create kvstore using the string type
        # NOTE: compare by value; identity ("is") against literals is
        # unreliable for non-interned strings and non-int integer types.
        if num_device == 1 and 'dist' not in kvstore:
            # no need to use kv for single device and single machine
            kv = None
        else:
            kv = kvs.create(kvstore)
            if kvstore == 'local':
                # automatically select a proper local: when the largest
                # parameter exceeds 16M elements, do not update on the kvstore
                sizes = [np.prod(param.shape) for param in arg_params.values()]
                if sizes and max(sizes) > 1024 * 1024 * 16:
                    update_on_kvstore = False
    elif isinstance(kvstore, kvs.KVStore):
        kv = kvstore
    else:
        raise TypeError('kvstore must be KVStore, str or None')

    if kv is None:
        update_on_kvstore = False

    return (kv, update_on_kvstore)
def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
                        update_on_kvstore):
    """Register every parameter with the kvstore.

    Each parameter is initialised under its position in ``param_arrays``
    using the value from ``arg_params``. When ``update_on_kvstore`` is
    true, the stored value is then pulled into that parameter's arrays.
    """
    for position, device_copies in enumerate(param_arrays):
        name = param_names[position]
        kvstore.init(position, arg_params[name])
        if not update_on_kvstore:
            continue
        # priority is the negated position, as in the update helpers
        kvstore.pull(position, device_copies, priority=-position)
def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore):
    """Update ``param_arrays`` from ``grad_arrays`` through the kvstore.

    For every parameter whose first gradient entry is not None, the
    gradients are pushed to the kvstore and the weights are pulled back
    into the parameter arrays.
    """
    for slot, (weights, grads) in enumerate(zip(param_arrays, grad_arrays)):
        if grads[0] is not None:
            # priority is the negated parameter index
            kvstore.push(slot, grads, priority=-slot)
            kvstore.pull(slot, weights, priority=-slot)
def _update_params(param_arrays, grad_arrays, updater, num_device,
                   kvstore=None):
    """Update ``param_arrays`` from ``grad_arrays`` with a local updater.

    Parameters whose first gradient entry is None are skipped. When a
    kvstore is supplied, gradients are pushed to it and pulled back into
    the same gradient arrays before ``updater`` is applied.
    """
    for index, (arg_list, grad_list) in enumerate(zip(param_arrays, grad_arrays)):
        if grad_list[0] is None:
            continue
        if kvstore:
            # push the gradients, then pull the result back in place
            kvstore.push(index, grad_list, priority=-index)
            kvstore.pull(index, grad_list, priority=-index)
        # Give each (parameter, device) pair its own updater index so the
        # updater keeps separate state per device copy.
        first_slot = index * num_device
        for offset, (weight, grad) in enumerate(zip(arg_list, grad_list)):
            updater(first_slot + offset, grad, weight)
def _multiple_callbacks(callbacks, *args, **kwargs):
    """Forward ``args`` and ``kwargs`` to the configured callback(s).

    ``callbacks`` may be a list of callables (each is invoked in order),
    a single callable, or a falsy value such as None (nothing happens).
    """
    if not isinstance(callbacks, list):
        if callbacks:
            callbacks(*args, **kwargs)
        return

    for callback in callbacks:
        callback(*args, **kwargs)
def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
                        arg_params, aux_params,
                        begin_epoch, end_epoch, epoch_size, optimizer,
                        kvstore, update_on_kvstore,
                        train_data, eval_data=None, eval_metric=None,
                        epoch_end_callback=None, batch_end_callback=None,
                        logger=None, work_load_list=None, monitor=None,
                        eval_end_callback=None,
                        eval_batch_end_callback=None, sym_gen=None):
    """Internal training function on multiple devices.

    This function will also work for single device as well.

    Parameters
    ----------
    symbol : Symbol
        The network configuration
    ctx : list of Context
        The training devices.
    arg_names: list of str
        Name of all arguments of the network.
    param_names: list of str
        Name of all trainable parameters of the network.
    aux_names: list of str
        Name of all auxiliary states of the network.
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    begin_epoch : int
        The beginning training epoch.
    end_epoch : int
        The end training epoch.
    epoch_size : int, optional
        Number of batches in a epoch. In default, it is set to
        ceil(num_train_examples / batch_size)
    optimizer : Optimizer
        The optimization algorithm
    train_data : DataIter
        Training data iterator.
    eval_data : DataIter
        Validation data iterator.
    eval_metric : EvalMetric
        An evaluation function or a list of evaluation functions.
        NOTE(review): the default is None, but ``eval_metric.reset()`` is
        called unconditionally at the start of every epoch, so callers
        must pass a metric.
    epoch_end_callback : callable(epoch, symbol, arg_params, aux_states)
        A callback that is invoked at end of each epoch.
        This can be used to checkpoint model each epoch.
    batch_end_callback : callable(BatchEndParams)
        A callback that is invoked at end of each batch.
        This can be used to measure speed, get result from evaluation metric. etc.
    kvstore : KVStore
        The KVStore
    update_on_kvstore : bool
        whether or not perform weight updating on kvstore
    logger : logging logger
        When not specified, default logger will be used.
    work_load_list : list of float or int, optional
        The list of work load for different devices,
        in the same order as ctx
    monitor : Monitor, optional
        Monitor installed to executor,
        for monitoring outputs, weights, and gradients for debugging.
    eval_end_callback : callable(BatchEndParams) or list of callables, optional
        Invoked once after each pass over ``eval_data``; ``nbatch`` holds the
        number of evaluated batches.
    eval_batch_end_callback : callable(BatchEndParams) or list of callables, optional
        Invoked after each evaluation batch; ``nbatch`` holds the zero-based
        batch index.
    sym_gen : callable, optional
        Forwarded unchanged to ``DataParallelExecutorManager``.

    Notes
    -----
    - This function will inplace update the NDArrays in arg_params and aux_states.
    - ``locals()`` of this function is handed to the batch/eval callbacks, so
      the local variable names below are part of what callbacks can observe.
    """
    if logger is None:
        logger = logging
    executor_manager = DataParallelExecutorManager(symbol=symbol,
                                                   sym_gen=sym_gen,
                                                   ctx=ctx,
                                                   train_data=train_data,
                                                   param_names=param_names,
                                                   arg_names=arg_names,
                                                   aux_names=aux_names,
                                                   work_load_list=work_load_list,
                                                   logger=logger)
    if monitor:
        executor_manager.install_monitor(monitor)

    executor_manager.set_params(arg_params, aux_params)

    # a local updater is only needed when updates do not run on the kvstore
    if not update_on_kvstore:
        updater = get_updater(optimizer)

    if kvstore:
        _initialize_kvstore(kvstore=kvstore,
                            param_arrays=executor_manager.param_arrays,
                            arg_params=arg_params,
                            param_names=executor_manager.param_names,
                            update_on_kvstore=update_on_kvstore)

    if update_on_kvstore:
        kvstore.set_optimizer(optimizer)

    # Now start training
    train_data.reset()
    for epoch in range(begin_epoch, end_epoch):
        # Training phase
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        # Iterate over training data. The iterator may be consumed several
        # times within one epoch when epoch_size exceeds its length.
        while True:
            do_reset = True
            for data_batch in train_data:
                executor_manager.load_data_batch(data_batch)

                if monitor is not None:
                    monitor.tic()

                executor_manager.forward(is_train=True)
                executor_manager.backward()

                if update_on_kvstore:
                    _update_params_on_kvstore(executor_manager.param_arrays,
                                              executor_manager.grad_arrays,
                                              kvstore)
                else:
                    _update_params(executor_manager.param_arrays,
                                   executor_manager.grad_arrays,
                                   updater=updater,
                                   num_device=len(ctx),
                                   kvstore=kvstore)

                if monitor is not None:
                    monitor.toc_print()

                # evaluate at end, so we can lazy copy
                executor_manager.update_metric(eval_metric, data_batch.label)

                nbatch += 1
                # batch callback (for print purpose)
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    _multiple_callbacks(batch_end_callback, batch_end_params)

                # this epoch is done possibly earlier; keep the iterator
                # position so the next epoch continues from here
                if epoch_size is not None and nbatch >= epoch_size:
                    do_reset = False
                    break

            if do_reset:
                logger.info('Epoch[%d] Resetting Data Iterator', epoch)
                train_data.reset()

            # this epoch is done
            if epoch_size is None or nbatch >= epoch_size:
                break

        toc = time.time()
        logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # copy parameters back only when someone will read them: an epoch
        # callback, or the caller after the final epoch
        if epoch_end_callback or epoch + 1 == end_epoch:
            executor_manager.copy_to(arg_params, aux_params)

        _multiple_callbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params)

        # evaluation
        if eval_data:
            eval_metric.reset()
            eval_data.reset()
            total_num_batch = 0
            for i, eval_batch in enumerate(eval_data):
                executor_manager.load_data_batch(eval_batch)
                executor_manager.forward(is_train=False)
                executor_manager.update_metric(eval_metric, eval_batch.label)
                if eval_batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=i,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    _multiple_callbacks(eval_batch_end_callback, batch_end_params)
                total_num_batch += 1
            if eval_end_callback is not None:
                eval_end_params = BatchEndParam(epoch=epoch,
                                                nbatch=total_num_batch,
                                                eval_metric=eval_metric,
                                                locals=locals())
                _multiple_callbacks(eval_end_callback, eval_end_params)
            eval_data.reset()
    # end of all epochs
def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params):
    """Write a model checkpoint to disk.

    Parameters
    ----------
    prefix : str
        Prefix of model name.
    epoch : int
        The epoch number of the model.
    symbol : Symbol
        The input symbol; skipped when None.
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.

    Notes
    -----
    - ``prefix-symbol.json`` will be saved for symbol.
    - ``prefix-epoch.params`` will be saved for parameters.
    """
    if symbol is not None:
        symbol.save('%s-symbol.json' % prefix)

    # Weights and auxiliary states share one file; the key prefix tells
    # them apart. Every array is moved to the CPU context before saving.
    save_dict = {}
    for name, array in arg_params.items():
        save_dict['arg:%s' % name] = array.as_in_context(cpu())
    for name, array in aux_params.items():
        save_dict['aux:%s' % name] = array.as_in_context(cpu())

    param_name = '%s-%04d.params' % (prefix, epoch)
    nd.save(param_name, save_dict)
    logging.info('Saved checkpoint to \"%s\"', param_name)
def load_checkpoint(prefix, epoch):
    """Read a model checkpoint from disk.

    Parameters
    ----------
    prefix : str
        Prefix of model name.
    epoch : int
        Epoch number of model we would like to load.

    Returns
    -------
    symbol : Symbol
        The symbol configuration of computation network, or None when
        ``prefix-symbol.json`` does not exist.
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.

    Notes
    -----
    - symbol will be loaded from ``prefix-symbol.json``.
    - parameters will be loaded from ``prefix-epoch.params``.
    """
    symbol_file = '%s-symbol.json' % prefix
    symbol = sym.load(symbol_file) if os.path.exists(symbol_file) else None

    arg_params, aux_params = {}, {}
    # route each 'kind:name' entry by its kind; other kinds are ignored
    destinations = {'arg': arg_params, 'aux': aux_params}
    for key, value in nd.load('%s-%04d.params' % (prefix, epoch)).items():
        kind, name = key.split(':', 1)
        if kind in destinations:
            destinations[kind][name] = value
    return (symbol, arg_params, aux_params)
from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position
class FeedForward(BASE_ESTIMATOR):
    """Model class of MXNet for training and predicting feedforward nets.

    This class is designed for a single-data single output supervised network.

    Parameters
    ----------
    symbol : Symbol or callable
        The symbol configuration of computation network, or a callable
        that generates it. A callable is invoked in ``fit`` with the
        training iterator's ``default_bucket_key``.
    ctx : Context or list of Context, optional
        The device context of training and prediction.
        To use multi GPU training, pass in a list of gpu contexts.
    num_epoch : int, optional
        Training parameter, number of training epochs(epochs).
    epoch_size : int, optional
        Number of batches in a epoch. In default, it is set to
        ceil(num_train_examples / batch_size)
    optimizer : str or Optimizer, optional
        Training parameter, name or optimizer object for training.
    initializer : initializer function, optional
        Training parameter, the initialization scheme used.
    numpy_batch_size : int, optional
        The batch size of training data.
        Only needed when input array is numpy.
    arg_params : dict of str to NDArray, optional
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray, optional
        Model parameter, dict of name to NDArray of net's auxiliary states.
    allow_extra_params : boolean, optional
        Whether allow extra parameters that are not needed by symbol
        to be passed by aux_params and arg_params.
        If this is True, no error will be thrown when aux_params and arg_params
        contain extra parameters than needed.
    begin_epoch : int, optional
        The beginning training epoch.
    kwargs : dict
        The additional keyword arguments passed to optimizer.
    """
    def __init__(self, symbol, ctx=None,
                 num_epoch=None, epoch_size=None, optimizer='sgd',
                 initializer=Uniform(0.01),
                 numpy_batch_size=128,
                 arg_params=None, aux_params=None,
                 allow_extra_params=False,
                 begin_epoch=0,
                 **kwargs):
        warnings.warn(
            '\033[91mmxnet.model.FeedForward has been deprecated. '
            'Please use mxnet.mod.Module instead.\033[0m',
            DeprecationWarning, stacklevel=2)

        if isinstance(symbol, sym.Symbol):
            self.symbol = symbol
            self.sym_gen = None
        else:
            assert callable(symbol), \
                'symbol must be a Symbol or a callable that generates one'
            self.symbol = None
            self.sym_gen = symbol

        # model parameters
        self.arg_params = arg_params
        self.aux_params = aux_params
        self.allow_extra_params = allow_extra_params

        # with a symbol generator the check is deferred until fit()
        self.argument_checked = False
        if self.sym_gen is None:
            self._check_arguments()

        # basic configuration: always store the contexts as a list
        if ctx is None:
            ctx = [cpu()]
        elif isinstance(ctx, Context):
            ctx = [ctx]
        self.ctx = ctx
        # training parameters
        self.num_epoch = num_epoch
        self.epoch_size = epoch_size
        self.kwargs = kwargs.copy()
        self.optimizer = optimizer
        self.initializer = initializer
        self.numpy_batch_size = numpy_batch_size
        # internal helper state
        self._pred_exec = None
        self.begin_epoch = begin_epoch

    def _check_arguments(self):
        """verify the argument of the default symbol and user provided parameters"""
        if self.argument_checked:
            return

        assert self.symbol is not None, 'symbol is not available yet'
        self.argument_checked = True

        # check if symbol contain duplicated names.
        # (this resolves to the module-level helper, not this method)
        _check_arguments(self.symbol)
        # rematch parameters to delete useless ones
        if self.allow_extra_params:
            if self.arg_params:
                arg_names = set(self.symbol.list_arguments())
                self.arg_params = {k: v for k, v in self.arg_params.items()
                                   if k in arg_names}
            if self.aux_params:
                aux_names = set(self.symbol.list_auxiliary_states())
                self.aux_params = {k: v for k, v in self.aux_params.items()
                                   if k in aux_names}

    @staticmethod
    def _is_data_arg(name):
        """Check if name is a data argument."""
        return name.endswith(('data', 'label'))

    def _init_params(self, input_shapes, overwrite=False):
        """Initialize weight parameters and auxiliary states

        Parameters
        ----------
        input_shapes : dict of str to shape
            Shapes of the network inputs, passed to ``symbol.infer_shape``.
        overwrite : bool, optional
            When True, existing values in ``self.arg_params`` /
            ``self.aux_params`` are ignored and everything is re-initialized.

        Returns
        -------
        (arg_names, param_names, aux_names) : tuple of list of str
            All argument names, the arguments that are not inputs, and
            the auxiliary state names.
        """
        arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
        assert arg_shapes is not None, 'Incomplete input shapes'

        arg_names = self.symbol.list_arguments()
        # a set keeps the membership tests below O(1)
        input_names = set(input_shapes.keys())
        param_names = [key for key in arg_names if key not in input_names]
        aux_names = self.symbol.list_auxiliary_states()

        param_name_shapes = [(name, shape) for name, shape in zip(arg_names, arg_shapes)
                             if name not in input_names]
        arg_params = {k: nd.zeros(s) for k, s in param_name_shapes}
        aux_params = {k: nd.zeros(s) for k, s in zip(aux_names, aux_shapes)}

        # reuse previously supplied values where available, otherwise
        # run the initializer on the fresh array
        for k, v in arg_params.items():
            if self.arg_params and k in self.arg_params and (not overwrite):
                arg_params[k][:] = self.arg_params[k][:]
            else:
                self.initializer(k, v)

        for k, v in aux_params.items():
            if self.aux_params and k in self.aux_params and (not overwrite):
                aux_params[k][:] = self.aux_params[k][:]
            else:
                self.initializer(k, v)

        self.arg_params = arg_params
        self.aux_params = aux_params
        return (arg_names, param_names, aux_names)

    def __getstate__(self):
        """Return the instance state for pickling, without the bound predictor."""
        this = self.__dict__.copy()
        this['_pred_exec'] = None
        return this

    def __setstate__(self, state):
        """Restore the instance state produced by ``__getstate__``."""
        self.__dict__.update(state)

    def _build_type_dict(self, data_iter):
        """Collect the dtype of every parameter and data input for binding the predictor."""
        type_dict = {key: value.dtype for key, value in self.arg_params.items()}
        for desc in data_iter.provide_data:
            if isinstance(desc, DataDesc):
                type_dict[desc.name] = desc.dtype
            else:
                # entries that are not DataDesc are indexed by position;
                # element 0 is used as the name and mx_real_t as the dtype
                type_dict[desc[0]] = mx_real_t
        return type_dict

    def _init_predictor(self, input_shapes, type_dict=None):
        """Initialize the predictor module for running prediction."""
        if self._pred_exec is not None:
            # reuse the bound executor when the argument shapes still match
            arg_shapes, _, _ = self.symbol.infer_shape(**dict(input_shapes))
            assert arg_shapes is not None, "Incomplete input shapes"
            pred_shapes = [x.shape for x in self._pred_exec.arg_arrays]
            if arg_shapes == pred_shapes:
                return
        # for now only use the first device
        pred_exec = self.symbol.simple_bind(
            self.ctx[0], grad_req='null', type_dict=type_dict, **dict(input_shapes))
        pred_exec.copy_params_from(self.arg_params, self.aux_params)

        _check_arguments(self.symbol)
        self._pred_exec = pred_exec

    def _init_iter(self, X, y, is_train):
        """Initialize the iterator given input.

        Array input is wrapped in an ``io.NDArrayIter``; a ``DataIter``
        is returned unchanged.

        Raises
        ------
        ValueError
            If ``y`` is missing for training, if the number of data points
            and labels differ, or if the label is not 1D (or 2D with a
            second dimension of 1).
        TypeError
            If ``X`` or ``y`` has an unsupported type.
        """
        if isinstance(X, (np.ndarray, nd.NDArray)):
            if y is None:
                if is_train:
                    raise ValueError('y must be specified when X is numpy.ndarray')
                # labels are not supplied outside training; use zeros as placeholders
                y = np.zeros(X.shape[0])
            if not isinstance(y, (np.ndarray, nd.NDArray)):
                raise TypeError('y must be ndarray when X is numpy.ndarray')
            if X.shape[0] != y.shape[0]:
                raise ValueError("The numbers of data points and labels not equal")
            if y.ndim == 2 and y.shape[1] == 1:
                y = y.flatten()
            if y.ndim != 1:
                raise ValueError("Label must be 1D or 2D (with 2nd dimension being 1)")
            batch_size = min(X.shape[0], self.numpy_batch_size)
            if is_train:
                return io.NDArrayIter(X, y, batch_size,
                                      shuffle=True, last_batch_handle='roll_over')
            return io.NDArrayIter(X, y, batch_size, shuffle=False)
        if not isinstance(X, io.DataIter):
            raise TypeError('X must be DataIter, NDArray or numpy.ndarray')
        return X

    def _init_eval_iter(self, eval_data):
        """Initialize the iterator given eval_data.

        Accepts None, a ``DataIter``, or a pair ``(data, label)`` whose
        elements may be lists (converted with ``np.array``), numpy arrays
        or NDArrays.
        """
        if eval_data is None:
            return eval_data
        if isinstance(eval_data, (tuple, list)) and len(eval_data) == 2:
            if eval_data[0] is None:
                raise ValueError("Eval data is NONE")
            if eval_data[1] is None and isinstance(eval_data[0], io.DataIter):
                return eval_data[0]
            input_data = (np.array(eval_data[0]) if isinstance(eval_data[0], list)
                          else eval_data[0])
            input_label = (np.array(eval_data[1]) if isinstance(eval_data[1], list)
                           else eval_data[1])
            # NOTE(review): is_train=True makes labels mandatory, but it also
            # builds a shuffled 'roll_over' iterator for evaluation data.
            # Confirm this is intended before changing it.
            return self._init_iter(input_data, input_label, is_train=True)
        if not isinstance(eval_data, io.DataIter):
            raise TypeError('Eval data must be DataIter, or '
                            'NDArray/numpy.ndarray/list pair (i.e. tuple/list of length 2)')
        return eval_data

    def predict(self, X, num_batch=None, return_data=False, reset=True):
        """Run the prediction, always only use one device.

        Parameters
        ----------
        X : mxnet.DataIter or numpy.ndarray/NDArray
        num_batch : int or None
            the number of batch to run. Go though all batches if None
        return_data : bool, optional
            When True, also return the data and labels that were fed in.
        reset : bool, optional
            When True, reset the iterator before predicting.

        Returns
        -------
        y : numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs.
            The predicted value of the output.
            When ``return_data`` is True a tuple ``(y, data, label)`` is returned.
        """
        X = self._init_iter(X, None, is_train=False)

        if reset:
            X.reset()
        data_shapes = X.provide_data
        data_names = [x[0] for x in data_shapes]
        type_dict = self._build_type_dict(X)

        self._init_predictor(data_shapes, type_dict)
        batch_size = X.batch_size
        data_arrays = [self._pred_exec.arg_dict[name] for name in data_names]
        output_list = [[] for _ in self._pred_exec.outputs]
        if return_data:
            data_list = [[] for _ in X.provide_data]
            label_list = [[] for _ in X.provide_label]

        i = 0
        for batch in X:
            _load_data(batch, data_arrays)
            self._pred_exec.forward(is_train=False)
            # drop the padded rows at the end of the batch
            real_size = batch_size - batch.pad

            for o_list, o_nd in zip(output_list, self._pred_exec.outputs):
                o_list.append(o_nd[0:real_size].asnumpy())

            if return_data:
                for j, x in enumerate(batch.data):
                    data_list[j].append(x[0:real_size].asnumpy())
                for j, x in enumerate(batch.label):
                    label_list[j].append(x[0:real_size].asnumpy())
            i += 1
            if num_batch is not None and i == num_batch:
                break

        outputs = [np.concatenate(x) for x in output_list]
        if len(outputs) == 1:
            outputs = outputs[0]

        if not return_data:
            return outputs

        data = [np.concatenate(x) for x in data_list]
        label = [np.concatenate(x) for x in label_list]
        if len(data) == 1:
            data = data[0]
        if len(label) == 1:
            label = label[0]
        return outputs, data, label

    def score(self, X, eval_metric='acc', num_batch=None, batch_end_callback=None, reset=True):
        """Run the model on X and calculate the score with eval_metric

        Parameters
        ----------
        X : mxnet.DataIter
        eval_metric : metric.metric
            The metric for calculating score
        num_batch : int or None
            the number of batch to run. Go though all batches if None
        batch_end_callback : callable(BatchEndParams) or list of callables, optional
            Invoked after every scored batch (with ``epoch`` fixed to 0).
        reset : bool, optional
            When True, reset the iterator before scoring.

        Returns
        -------
        s : float
            the final score
        """
        # setup metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        X = self._init_iter(X, None, is_train=False)
        if reset:
            X.reset()

        data_shapes = X.provide_data
        data_names = [x[0] for x in data_shapes]
        type_dict = self._build_type_dict(X)

        self._init_predictor(data_shapes, type_dict)
        data_arrays = [self._pred_exec.arg_dict[name] for name in data_names]

        for i, batch in enumerate(X):
            if num_batch is not None and i == num_batch:
                break
            _load_data(batch, data_arrays)
            self._pred_exec.forward(is_train=False)
            eval_metric.update(batch.label, self._pred_exec.outputs)

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=0,
                                                 nbatch=i,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                _multiple_callbacks(batch_end_callback, batch_end_params)
        return eval_metric.get()[1]

    def fit(self, X, y=None, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None,
            work_load_list=None, monitor=None, eval_end_callback=LogValidationMetricsCallback(),
            eval_batch_end_callback=None):
        """Fit the model.

        Parameters
        ----------
        X : DataIter, or numpy.ndarray/NDArray
            Training data. If X is an DataIter, the name or, if not available,
            position, of its outputs should match the corresponding variable
            names defined in the symbolic graph.
        y : numpy.ndarray/NDArray, optional
            Training set label.
            If X is numpy.ndarray/NDArray, y is required to be set.
            While y can be 1D or 2D (with 2nd dimension as 1), its 1st dimension must be
            the same as X, i.e. the number of data points and labels should be equal.
        eval_data : DataIter or numpy.ndarray/list/NDArray pair
            If eval_data is numpy.ndarray/list/NDArray pair,
            it should be (valid_data, valid_label).
        eval_metric : metric.EvalMetric or str or callable
            The evaluation metric, name of evaluation metric.
            Or a customize evaluation function that returns the statistics
            based on minibatch.
        epoch_end_callback : callable(epoch, symbol, arg_params, aux_states)
            A callback that is invoked at end of each epoch.
            This can be used to checkpoint model each epoch.
        batch_end_callback: callable(BatchEndParams)
            A callback that is invoked at end of each batch
            For print purpose
        kvstore: KVStore or str, optional
           The KVStore or a string kvstore type: 'local', 'dist_sync', 'dist_async'
           In default uses 'local', often no need to change for single machine.
        logger : logging logger, optional
            When not specified, default logger will be used.
        work_load_list : list of float or int, optional
            The list of work load for different devices,
            in the same order as ctx
        monitor : Monitor, optional
            Forwarded to the internal training function.
        eval_end_callback : callable(BatchEndParams) or list of callables, optional
            Forwarded to the internal training function.
        eval_batch_end_callback : callable(BatchEndParams) or list of callables, optional
            Forwarded to the internal training function.

        Raises
        ------
        TypeError
            If ``self.optimizer`` is neither a str nor an Optimizer.

        Note
        ----
        KVStore behavior
        - 'local', multi-devices on a single machine, will automatically choose best type.
        - 'dist_sync', multi-machines with BSP
        - 'dist_async', multi-machines with partial asynchronous
        """
        data = self._init_iter(X, y, is_train=True)
        eval_data = self._init_eval_iter(eval_data)

        if self.sym_gen:
            self.symbol = self.sym_gen(data.default_bucket_key) # pylint: disable=no-member
            self._check_arguments()
        self.kwargs["sym"] = self.symbol

        arg_names, param_names, aux_names = \
                self._init_params(dict(data.provide_data+data.provide_label))

        # setup metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        # create kvstore
        (kvstore, update_on_kvstore) = _create_kvstore(
            kvstore, len(self.ctx), self.arg_params)

        # Map optimizer indices back to parameter names. Without
        # update_on_kvstore each device copy has its own index
        # (param_index * num_device + device_index).
        param_idx2name = {}
        if update_on_kvstore:
            param_idx2name.update(enumerate(param_names))
        else:
            num_device = len(self.ctx)
            for i, n in enumerate(param_names):
                for k in range(num_device):
                    param_idx2name[i*num_device+k] = n
        self.kwargs["param_idx2name"] = param_idx2name

        # init optimizer
        if isinstance(self.optimizer, str):
            batch_size = data.batch_size
            if kvstore and 'dist' in kvstore.type and '_async' not in kvstore.type:
                batch_size *= kvstore.num_workers
            optimizer = opt.create(self.optimizer,
                                   rescale_grad=(1.0/batch_size),
                                   **(self.kwargs))
        elif isinstance(self.optimizer, opt.Optimizer):
            optimizer = self.optimizer
        else:
            # previously this fell through and failed later with an
            # unbound local variable
            raise TypeError('optimizer must be str or Optimizer, got %s'
                            % type(self.optimizer).__name__)

        # do training
        _train_multi_device(self.symbol, self.ctx, arg_names, param_names, aux_names,
                            self.arg_params, self.aux_params,
                            begin_epoch=self.begin_epoch, end_epoch=self.num_epoch,
                            epoch_size=self.epoch_size,
                            optimizer=optimizer,
                            train_data=data, eval_data=eval_data,
                            eval_metric=eval_metric,
                            epoch_end_callback=epoch_end_callback,
                            batch_end_callback=batch_end_callback,
                            kvstore=kvstore, update_on_kvstore=update_on_kvstore,
                            logger=logger, work_load_list=work_load_list, monitor=monitor,
                            eval_end_callback=eval_end_callback,
                            eval_batch_end_callback=eval_batch_end_callback,
                            sym_gen=self.sym_gen)

    def save(self, prefix, epoch=None):
        """Checkpoint the model checkpoint into file.

        You can also use pickle to do the job if you only work on python.
        The advantage of load/save is the file is language agnostic.
        This means the file saved using save can be loaded by other language binding of mxnet.
        You also get the benefit being able to directly load/save from cloud storage(S3, HDFS)

        Parameters
        ----------
        prefix : str
            Prefix of model name.
        epoch : int, optional
            Epoch number used in the parameter file name.
            Defaults to ``self.num_epoch``.

        Notes
        -----
        - ``prefix-symbol.json`` will be saved for symbol.
        - ``prefix-epoch.params`` will be saved for parameters.
        """
        if epoch is None:
            epoch = self.num_epoch
        assert epoch is not None, 'epoch must be given when num_epoch is not set'
        save_checkpoint(prefix, epoch, self.symbol, self.arg_params, self.aux_params)

    @staticmethod
    def load(prefix, epoch, ctx=None, **kwargs):
        """Load model checkpoint from file.

        Parameters
        ----------
        prefix : str
            Prefix of model name.
        epoch : int
            epoch number of model we would like to load.
        ctx : Context or list of Context, optional
            The device context of training and prediction.
        kwargs : dict
            other parameters for model, including num_epoch, optimizer and numpy_batch_size

        Returns
        -------
        model : FeedForward
            The loaded model that can be used for prediction.

        Notes
        -----
        - symbol will be loaded from ``prefix-symbol.json``.
        - parameters will be loaded from ``prefix-epoch.params``.
        """
        symbol, arg_params, aux_params = load_checkpoint(prefix, epoch)
        return FeedForward(symbol, ctx=ctx,
                           arg_params=arg_params, aux_params=aux_params,
                           begin_epoch=epoch,
                           **kwargs)

    @staticmethod
    def create(symbol, X, y=None, ctx=None,
               num_epoch=None, epoch_size=None, optimizer='sgd', initializer=Uniform(0.01),
               eval_data=None, eval_metric='acc',
               epoch_end_callback=None, batch_end_callback=None,
               kvstore='local', logger=None, work_load_list=None,
               eval_end_callback=LogValidationMetricsCallback(),
               eval_batch_end_callback=None, **kwargs):
        """Functional style to create a model.

        This function will be more consistent with functional
        languages such as R, where mutation is not allowed.

        Parameters
        ----------
        symbol : Symbol
            The symbol configuration of computation network.
        X : DataIter
            Training data
        y : numpy.ndarray, optional
            If X is numpy.ndarray y is required to set
        ctx : Context or list of Context, optional
            The device context of training and prediction.
            To use multi GPU training, pass in a list of gpu contexts.
        num_epoch : int, optional
            Training parameter, number of training epochs(epochs).
        epoch_size : int, optional
            Number of batches in a epoch. In default, it is set to
            ceil(num_train_examples / batch_size)
        optimizer : str or Optimizer, optional
            Training parameter, name or optimizer object for training.
        initializer : initializer function, optional
            Training parameter, the initialization scheme used.
        eval_data : DataIter or numpy.ndarray pair
            If eval_set is numpy.ndarray pair, it should be (valid_data, valid_label)
        eval_metric : metric.EvalMetric or str or callable
            The evaluation metric, name of evaluation metric.
            Or a customize evaluation function that returns the statistics
            based on minibatch.
        epoch_end_callback : callable(epoch, symbol, arg_params, aux_states)
            A callback that is invoked at end of each epoch.
            This can be used to checkpoint model each epoch.
        batch_end_callback: callable(BatchEndParams)
            A callback that is invoked at end of each batch
            For print purpose
        kvstore: KVStore or str, optional
           The KVStore or a string kvstore type: 'local', 'dist_sync', 'dist_async'
           In default uses 'local', often no need to change for single machine.
        logger : logging logger, optional
            When not specified, default logger will be used.
        work_load_list : list of float or int, optional
            The list of work load for different devices,
            in the same order as ctx
        eval_end_callback : callable(BatchEndParams) or list of callables, optional
            Forwarded to ``fit``.
        eval_batch_end_callback : callable(BatchEndParams) or list of callables, optional
            Forwarded to ``fit``.
        kwargs : dict
            Additional keyword arguments forwarded to the ``FeedForward`` constructor.

        Returns
        -------
        model : FeedForward
            The fitted model.
        """
        model = FeedForward(symbol, ctx=ctx, num_epoch=num_epoch,
                            epoch_size=epoch_size,
                            optimizer=optimizer, initializer=initializer, **kwargs)
        model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric,
                  epoch_end_callback=epoch_end_callback,
                  batch_end_callback=batch_end_callback,
                  kvstore=kvstore,
                  logger=logger,
                  work_load_list=work_load_list,
                  eval_end_callback=eval_end_callback,
                  eval_batch_end_callback=eval_batch_end_callback)
        return model