| # pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access, too-many-branches |
| # pylint: disable=too-many-public-methods, too-many-statements |
| """A `Module` implement the `BaseModule` API by wrapping a `Symbol` and one or |
| more `Executor` for data parallelization. |
| """ |
| |
| import os |
| import json |
| import logging |
| import warnings |
| |
| from .. import context as ctx |
| from .. import ndarray as nd |
| from .. import symbol as _sym |
| from .. import optimizer as opt |
| from .. import loss |
| from ..base import _Sentinel, __version__, string_types |
| |
| from .executor_group import DataParallelExecutorGroup |
| from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore |
| from ..model import load_checkpoint |
| from ..initializer import Uniform, InitDesc |
| |
| from .base_module import BaseModule, _check_input_names, _parse_data_desc |
| |
| |
| class Module(BaseModule): |
| """Module is a basic module that wrap a `Loss` or `Symbol`. |
| It is functionally the same as the `FeedForward` model. |
| |
| Parameters |
| ---------- |
| symbol : Symbol |
| data_names : list of str |
| Default is `('data',)` for a typical model used in image classification.
| label_names : list of str |
| Default is `('softmax_label',)` for a typical model used in image
| classification.
| logger : Logger |
| Default is `logging`. |
| context : Context or list of Context |
| Default is `cpu()`. |
| work_load_list : list of number |
| Default `None`, indicating uniform workload. |
| fixed_param_names: list of str |
| Default `None`, indicating no network parameters are fixed. |
| state_names : list of str |
| States are similar to data and label, but are not provided by the data
| iterator. Instead, they are initialized to 0 and can be set by `set_states()`.
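|
| Examples
| --------
| A minimal construction sketch (the tiny network here is hypothetical)::
|
| >>> import mxnet as mx
| >>> net = mx.sym.Variable('data')
| >>> net = mx.sym.FullyConnected(net, name='fc1', num_hidden=10)
| >>> mod = mx.mod.Module(net, data_names=('data',), context=mx.cpu())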
| """ |
| def __init__(self, symbol, data_names=('data',), label_names=_Sentinel, |
| logger=logging, context=ctx.cpu(), work_load_list=None, |
| fixed_param_names=None, state_names=None, **kwargs): |
| super(Module, self).__init__(logger=logger) |
| |
| if isinstance(symbol, string_types): |
| symbol = _sym.load_json(symbol) |
| elif isinstance(symbol, dict): |
| symbol = loss.create(**symbol) |
| |
| if isinstance(context, ctx.Context): |
| context = [context] |
| self._context = context |
| if work_load_list is None: |
| work_load_list = [1] * len(self._context) |
| assert len(work_load_list) == len(self._context) |
| self._work_load_list = work_load_list |
| |
| self._kwargs = kwargs |
| if isinstance(symbol, loss.BaseLoss): |
| self._kwargs['symbol'] = symbol.get_config() |
| else: |
| self._kwargs['symbol'] = symbol.tojson() |
| self._kwargs.update({ |
| 'data_names': data_names, |
| 'fixed_param_names': fixed_param_names, |
| 'state_names': state_names, |
| '__type__': 'module', |
| '__version__': __version__}) |
| |
| if isinstance(symbol, loss.BaseLoss): |
| self._loss = symbol |
| self._symbol = _sym.Group([self._loss.output_symbol, self._loss.loss_symbol]) |
| num_output = len(self._loss.output_symbol.list_outputs()) |
| num_loss = len(self._loss.loss_symbol.list_outputs()) |
| self._output_range = (0, num_output) |
| self._loss_range = (num_output, num_output+num_loss) |
| assert label_names is _Sentinel, \ |
| "label_names has been deprecated. Do not set." |
| label_names = self._loss.label_names |
| else: |
| self._symbol = symbol |
| self._loss = None |
| self._output_range = (0, len(self._symbol.list_outputs())) |
| if label_names is _Sentinel: |
| label_names = ('softmax_label',) |
| else: |
| warnings.warn(
| "label_names has been deprecated. For prediction only, "
| "do not set label_names. For training, please pass one of the "
| "new mxnet.loss.* classes as symbol", stacklevel=2)
| |
| data_names = list(data_names) if data_names is not None else [] |
| label_names = list(label_names) if label_names is not None else [] |
| state_names = list(state_names) if state_names is not None else [] |
| fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else [] |
| |
| _check_input_names(self._symbol, data_names, "data", True) |
| _check_input_names(self._symbol, label_names, "label", False) |
| _check_input_names(self._symbol, state_names, "state", True) |
| _check_input_names(self._symbol, fixed_param_names, "fixed_param", True) |
| |
| arg_names = self._symbol.list_arguments() |
| input_names = data_names + label_names + state_names |
| self._param_names = [x for x in arg_names if x not in input_names] |
| self._fixed_param_names = fixed_param_names |
| self._aux_names = self._symbol.list_auxiliary_states() |
| self._data_names = data_names |
| self._label_names = label_names |
| self._state_names = state_names |
| self._output_names = self._symbol.list_outputs() |
| |
| self._arg_params = None |
| self._aux_params = None |
| self._params_dirty = False |
| |
| self._optimizer = None |
| self._kvstore = None |
| self._update_on_kvstore = None |
| self._updater = None |
| self._preload_opt_states = None |
| self._grad_req = None |
| |
| self._exec_group = None |
| self._data_shapes = None |
| self._label_shapes = None |
| |
| @staticmethod |
| def load(prefix, epoch, load_optimizer_states=False, **kwargs): |
| """Create a model from previously saved checkpoint. |
| |
| For example, use::
|
| mod = mx.mod.Module.load('test', 100, context=mx.gpu(0))
|
| to load from "test-module.json" and "test-0100.params".
| |
| Parameters |
| ---------- |
| prefix : str |
| path prefix of saved model files. You should have |
| "prefix-symbol.json"/"prefix-module.json", |
| "prefix-xxxx.params", and optionally "prefix-xxxx.states", |
| where xxxx is the epoch number. |
| epoch : int |
| epoch to load. |
| load_optimizer_states : bool |
| whether to load optimizer states. Checkpoint needs |
| to have been made with save_optimizer_states=True. |
| context : Context or list of Context |
| Default is `cpu()`. |
| work_load_list : list of number |
| Default `None`, indicating uniform workload. |
| logger : Logger |
| Default is `logging`. |
| """ |
| sym, args, auxs = load_checkpoint(prefix, epoch) |
| if os.path.exists('%s-module.json'%prefix):
| with open('%s-module.json'%prefix) as fin:
| config = json.load(fin)
| config.update(kwargs) |
| mod = Module(**config) |
| else: |
| mod = Module(sym, **kwargs) |
| |
| mod._arg_params = args |
| mod._aux_params = auxs |
| mod.params_initialized = True |
| if load_optimizer_states: |
| mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch) |
| return mod |
| |
| def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): |
| """Save current progress to checkpoint. |
| Use mx.callback.module_checkpoint as |
| epoch_end_callback to save during training. |
| |
| Outputs 'prefix-module.json', 'prefix-symbol.json', |
| 'prefix-(epoch).params', and optionally 'prefix-(epoch).states'. |
| |
| Parameters |
| ---------- |
| prefix : str |
| The file prefix to checkpoint to |
| epoch : int |
| The current epoch number |
| save_optimizer_states : bool |
| Whether to save optimizer states for continued training.
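|
| Examples
| --------
| A sketch, assuming `mod` is a bound module with initialized parameters::
|
| >>> mod.save_checkpoint('mymodel', 10)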
| """ |
| if self._loss is not None: |
| self._loss.output_symbol.save('%s-symbol.json'%prefix) |
| else: |
| self._symbol.save('%s-symbol.json'%prefix) |
| with open('%s-module.json'%prefix, 'w') as fout: |
| json.dump(self._kwargs, fout) |
| |
| param_name = '%s-%04d.params' % (prefix, epoch) |
| self.save_params(param_name) |
| logging.info('Saved checkpoint to \"%s\"', param_name) |
| if save_optimizer_states: |
| state_name = '%s-%04d.states' % (prefix, epoch) |
| self.save_optimizer_states(state_name) |
| logging.info('Saved optimizer state to \"%s\"', state_name) |
| |
| def _reset_bind(self): |
| """Internal function to reset binded state.""" |
| if self.binded and self.params_initialized: |
| self._sync_params_from_devices() |
| self.binded = False |
| self._exec_group = None |
| self._data_shapes = None |
| self._label_shapes = None |
| |
| @property |
| def data_names(self): |
| """A list of names for data required by this module.""" |
| return self._data_names |
| |
| @property |
| def label_names(self): |
| """A list of names for labels required by this module.""" |
| return self._label_names |
| |
| @property |
| def output_names(self): |
| """A list of names for the outputs of this module.""" |
| return self._output_names[self._output_range[0]:self._output_range[1]] |
| |
| @property |
| def data_shapes(self): |
| """Get data shapes. |
| |
| Returns |
| ------- |
| A list of `(name, shape)` pairs. |
| """ |
| assert self.binded |
| return self._data_shapes |
| |
| @property |
| def label_shapes(self): |
| """Get label shapes. |
| |
| Returns |
| ------- |
| A list of `(name, shape)` pairs. The return value could be `None` if
| the module does not need labels, or if the module is not bound for
| training (in this case, label information is not available).
| """ |
| assert self.binded |
| return self._label_shapes |
| |
| @property |
| def output_shapes(self): |
| """Get output shapes. |
| |
| Returns |
| ------- |
| A list of `(name, shape)` pairs. |
| """ |
| assert self.binded |
| return self._exec_group.get_output_shapes() |
| |
| def get_params(self): |
| """Get current parameters. |
| Returns |
| ------- |
| `(arg_params, aux_params)`, each a dictionary of name to parameters (in |
| `NDArray`) mapping. |
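|
| Examples
| --------
| A sketch copying the current parameters out of a trained module::
|
| >>> arg_params, aux_params = mod.get_params()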
| """ |
| assert self.binded and self.params_initialized |
| |
| if self._params_dirty: |
| self._sync_params_from_devices() |
| return (self._arg_params, self._aux_params) |
| |
| def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None, |
| allow_missing=False, force_init=False): |
| """Initialize the parameters and auxiliary states. |
| |
| Parameters |
| ---------- |
| initializer : Initializer |
| Called to initialize parameters if needed. |
| arg_params : dict |
| If not None, should be a dictionary of existing arg_params. Initialization |
| will be copied from that. |
| aux_params : dict |
| If not None, should be a dictionary of existing aux_params. Initialization |
| will be copied from that. |
| allow_missing : bool |
| If true, params could contain missing values, and the initializer will be |
| called to fill those missing params. |
| force_init : bool |
| If true, will force re-initialize even if already initialized. |
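|
| Examples
| --------
| A sketch, assuming `mod` has already been bound::
|
| >>> mod.init_params(initializer=mx.init.Uniform(scale=0.01))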
| """ |
| if self.params_initialized and not force_init: |
| warnings.warn("Parameters already initialized and force_init=False. " |
| "init_params call ignored.", stacklevel=2) |
| return |
| assert self.binded, 'call bind before initializing the parameters' |
| |
| def _impl(name, arr, cache): |
| """Internal helper for parameter initialization""" |
| if cache is not None: |
| if name in cache: |
| cache_arr = cache[name] |
| |
| # in case the cached array is the target array itself
| if cache_arr is not arr: |
| cache_arr.copyto(arr) |
| else: |
| if not allow_missing: |
| raise RuntimeError("%s is not present" % name)
| if initializer is not None:
| initializer(name, arr) |
| else: |
| initializer(name, arr) |
| |
| attrs = self._symbol.attr_dict() |
| for name, arr in self._arg_params.items(): |
| desc = InitDesc(name, attrs.get(name, None)) |
| _impl(desc, arr, arg_params) |
| |
| for name, arr in self._aux_params.items(): |
| desc = InitDesc(name, attrs.get(name, None)) |
| _impl(desc, arr, aux_params) |
| |
| self.params_initialized = True |
| self._params_dirty = False |
| |
| # copy the initialized parameters to devices |
| self._exec_group.set_params(self._arg_params, self._aux_params) |
| |
| def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True): |
| """Assign parameter and aux state values. |
| |
| Parameters |
| ---------- |
| arg_params : dict |
| Dictionary of name to value (`NDArray`) mapping. |
| aux_params : dict |
| Dictionary of name to value (`NDArray`) mapping. |
| allow_missing : bool |
| If true, params could contain missing values, and the initializer will be |
| called to fill those missing params. |
| force_init : bool |
| If true, will force re-initialize even if already initialized. |
| |
| Examples |
| -------- |
| An example of setting module parameters::
|
| >>> sym, arg_params, aux_params = mx.model.load_checkpoint(
| ... model_prefix, n_epoch_load)
| >>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
| """ |
| if not allow_missing: |
| self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params, |
| allow_missing=allow_missing, force_init=force_init) |
| return |
| |
| if self.params_initialized and not force_init: |
| warnings.warn("Parameters already initialized and force_init=False. " |
| "set_params call ignored.", stacklevel=2) |
| return |
| |
| self._exec_group.set_params(arg_params, aux_params) |
| |
| # because we didn't update self._arg_params, they are dirty now. |
| self._params_dirty = True |
| self.params_initialized = True |
| |
| def bind(self, data_shapes, label_shapes=None, for_training=True, |
| inputs_need_grad=False, force_rebind=False, shared_module=None, |
| grad_req='write'): |
| """Bind the symbols to construct executors. This is necessary before one |
| can perform computation with the module. |
| |
| Parameters |
| ---------- |
| data_shapes : list of (str, tuple) |
| Typically is `data_iter.provide_data`. |
| label_shapes : list of (str, tuple) |
| Typically is `data_iter.provide_label`. |
| for_training : bool |
| Default is `True`. Whether the executors should be bound for training.
| inputs_need_grad : bool |
| Default is `False`. Whether the gradients to the input data need to be computed. |
| Typically this is not needed. But this might be needed when implementing composition |
| of modules. |
| force_rebind : bool |
| Default is `False`. This function does nothing if the executors are already
| bound. But if this is `True`, the executors will be forced to rebind.
| shared_module : Module |
| Default is `None`. This is used in bucketing. When not `None`, the shared module |
| essentially corresponds to a different bucket -- a module with different symbol |
| but with the same sets of parameters (e.g. unrolled RNNs with different lengths). |
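|
| Examples
| --------
| A sketch binding for training (the shapes are hypothetical)::
|
| >>> mod.bind(data_shapes=[('data', (32, 100))],
| ... label_shapes=[('softmax_label', (32,))])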
| """ |
| # force rebinding is typically used when one wants to switch from
| # training to the prediction phase.
| if force_rebind: |
| self._reset_bind() |
| |
| if self.binded: |
| self.logger.warning('Already bound, ignoring bind()')
| return |
| |
| self.for_training = for_training |
| self.inputs_need_grad = inputs_need_grad |
| self.binded = True |
| self._grad_req = grad_req |
| |
| if not for_training: |
| assert not inputs_need_grad |
| elif self._loss is None: |
| warnings.warn( |
| "Training with ***Output symbols is deprecated. " |
| "Please use mxnet.loss.* classes instead. See " |
| "mxnet.mod.Module's document for usage") |
| |
| self._data_shapes, self._label_shapes = _parse_data_desc( |
| self.data_names, self.label_names, data_shapes, label_shapes) |
| |
| if shared_module is not None: |
| assert isinstance(shared_module, Module) and \ |
| shared_module.binded and shared_module.params_initialized |
| shared_group = shared_module._exec_group |
| else: |
| shared_group = None |
| |
| self._exec_group = DataParallelExecutorGroup( |
| self._symbol, self._context, self._work_load_list, self._data_shapes, |
| self._label_shapes, self._param_names, for_training, inputs_need_grad, |
| shared_group, logger=self.logger, fixed_param_names=self._fixed_param_names, |
| grad_req=grad_req, state_names=self._state_names) |
| |
| self._total_exec_bytes = self._exec_group._total_exec_bytes |
| if shared_module is not None: |
| self.params_initialized = True |
| self._arg_params = shared_module._arg_params |
| self._aux_params = shared_module._aux_params |
| elif self.params_initialized: |
| # if the parameters are already initialized, we are re-binding |
| # so automatically copy the already initialized params |
| self._exec_group.set_params(self._arg_params, self._aux_params) |
| else: |
| assert self._arg_params is None and self._aux_params is None |
| param_arrays = [ |
| nd.zeros(x[0].shape, dtype=x[0].dtype) |
| for x in self._exec_group.param_arrays |
| ] |
| self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)} |
| |
| aux_arrays = [ |
| nd.zeros(x[0].shape, dtype=x[0].dtype) |
| for x in self._exec_group.aux_arrays |
| ] |
| self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)} |
| |
| if shared_module is not None and shared_module.optimizer_initialized: |
| self.borrow_optimizer(shared_module) |
| |
| |
| def reshape(self, data_shapes, label_shapes=None): |
| """Reshape the module for new input shapes. |
| |
| Parameters |
| ---------- |
| data_shapes : list of (str, tuple) |
| Typically is `data_iter.provide_data`. |
| label_shapes : list of (str, tuple) |
| Typically is `data_iter.provide_label`. |
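|
| Examples
| --------
| A sketch growing the batch size of an already bound module (hypothetical
| shapes)::
|
| >>> mod.reshape(data_shapes=[('data', (64, 100))])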
| """ |
| assert self.binded |
| self._data_shapes, self._label_shapes = _parse_data_desc( |
| self.data_names, self.label_names, data_shapes, label_shapes) |
| |
| self._exec_group.reshape(self._data_shapes, self._label_shapes) |
| |
| def init_optimizer(self, kvstore='local', optimizer='sgd', |
| optimizer_params=(('learning_rate', 0.01),), force_init=False): |
| """Install and initialize optimizers. |
| |
| Parameters |
| ---------- |
| kvstore : str or KVStore |
| Default `'local'`. |
| optimizer : str or Optimizer |
| Default `'sgd'` |
| optimizer_params : dict |
| Default `(('learning_rate', 0.01),)`. The default value is a tuple of
| tuples rather than a dictionary, only to avoid pylint's warning about
| dangerous default values.
| force_init : bool |
| Default `False`, indicating whether we should force re-initializing the |
| optimizer in the case an optimizer is already installed. |
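|
| Examples
| --------
| A sketch, assuming `mod` is bound and its parameters are initialized::
|
| >>> mod.init_optimizer(optimizer='sgd',
| ... optimizer_params=(('learning_rate', 0.01),))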
| """ |
| assert self.binded and self.params_initialized |
| |
| if self.optimizer_initialized and not force_init: |
| self.logger.warning('optimizer already initialized, ignoring...') |
| return |
| |
| (kvstore, update_on_kvstore) = \ |
| _create_kvstore(kvstore, len(self._context), self._arg_params) |
| |
| batch_size = self._exec_group.batch_size |
| if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: |
| batch_size *= kvstore.num_workers |
| rescale_grad = 1.0/batch_size |
| |
| if isinstance(optimizer, string_types):
| idx2name = {} |
| if update_on_kvstore: |
| idx2name.update(enumerate(self._exec_group.param_names)) |
| else: |
| for k in range(len(self._context)): |
| idx2name.update({i*len(self._context)+k: n |
| for i, n in enumerate(self._exec_group.param_names)}) |
| optimizer_params = dict(optimizer_params) |
| if 'rescale_grad' not in optimizer_params: |
| optimizer_params['rescale_grad'] = rescale_grad |
| optimizer = opt.create(optimizer, |
| sym=self._symbol, param_idx2name=idx2name, |
| **optimizer_params) |
| else: |
| assert isinstance(optimizer, opt.Optimizer) |
| if optimizer.rescale_grad != rescale_grad: |
| #pylint: disable=no-member |
| warnings.warn( |
| "Optimizer created manually outside Module but rescale_grad " + |
| "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%( |
| optimizer.rescale_grad, rescale_grad) + |
| "Is this intended?", stacklevel=2) |
| |
| self._optimizer = optimizer |
| self._kvstore = kvstore |
| self._update_on_kvstore = update_on_kvstore |
| self._updater = None |
| |
| if kvstore: |
| # copy initialized local parameters to kvstore |
| _initialize_kvstore(kvstore=kvstore, |
| param_arrays=self._exec_group.param_arrays, |
| arg_params=self._arg_params, |
| param_names=self._param_names, |
| update_on_kvstore=update_on_kvstore) |
| if update_on_kvstore: |
| kvstore.set_optimizer(self._optimizer) |
| else: |
| self._updater = opt.get_updater(optimizer) |
| |
| self.optimizer_initialized = True |
| |
| if self._preload_opt_states is not None: |
| self.load_optimizer_states(self._preload_opt_states) |
| self._preload_opt_states = None |
| |
| def borrow_optimizer(self, shared_module): |
| """Borrow optimizer from a shared module. Used in bucketing, where exactly the same |
| optimizer (esp. kvstore) is used. |
| |
| Parameters |
| ---------- |
| shared_module : Module |
| """ |
| assert shared_module.optimizer_initialized |
| self._optimizer = shared_module._optimizer |
| self._kvstore = shared_module._kvstore |
| self._update_on_kvstore = shared_module._update_on_kvstore |
| self._updater = shared_module._updater |
| self.optimizer_initialized = True |
| |
| def forward(self, data_batch, is_train=None): |
| """Forward computation. |
| |
| Parameters |
| ---------- |
| data_batch : DataBatch |
| Could be anything with a similar API implemented.
| is_train : bool |
| Default is `None`, which means `is_train` takes the value of `self.for_training`. |
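|
| Examples
| --------
| A sketch, where `data_batch` comes from a data iterator (hypothetical)::
|
| >>> mod.forward(data_batch, is_train=True)
| >>> outputs = mod.get_outputs()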
| """ |
| assert self.binded and self.params_initialized |
| self._exec_group.forward(data_batch, is_train) |
| |
| def backward(self, out_grads=None): |
| """Backward computation. |
| |
| Parameters |
| ---------- |
| out_grads : NDArray or list of NDArray, optional |
| Gradient on the outputs to be propagated back. |
| This parameter is only needed when bind is called |
| on outputs that are not a loss function. |
| """ |
| assert self.binded and self.params_initialized |
| self._exec_group.backward(out_grads=out_grads) |
| |
| def update(self): |
| """Update parameters according to the installed optimizer and the gradients computed |
| in the previous forward-backward batch. |
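|
| Examples
| --------
| A typical training step chains forward, backward and update (a sketch;
| `data_batch` is hypothetical)::
|
| >>> mod.forward(data_batch, is_train=True)
| >>> mod.backward()
| >>> mod.update()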
| """ |
| assert self.binded and self.params_initialized and self.optimizer_initialized |
| |
| self._params_dirty = True |
| if self._update_on_kvstore: |
| _update_params_on_kvstore(self._exec_group.param_arrays, |
| self._exec_group.grad_arrays, |
| self._kvstore) |
| else: |
| _update_params(self._exec_group.param_arrays, |
| self._exec_group.grad_arrays, |
| updater=self._updater, |
| num_device=len(self._context), |
| kvstore=self._kvstore) |
| |
| def get_outputs(self, merge_multi_context=True): |
| """Get outputs of the previous forward computation. |
| |
| If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it |
| is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output |
| elements are `NDArray`. When `merge_multi_context` is `False`, those `NDArray` |
| might live on different devices. |
| |
| Parameters |
| ---------- |
| merge_multi_context : bool |
| Default is `True`. In the case when data-parallelism is used, the outputs
| will be collected from multiple devices. A `True` value indicates that we
| should merge the collected results so that they look like the output of a
| single executor.
| |
| Returns |
| ------- |
| list of NDArray or list of list of NDArray |
| Output |
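|
| Examples
| --------
| A sketch reading the first output as a NumPy array after a forward pass::
|
| >>> out = mod.get_outputs()[0].asnumpy()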
| """ |
| assert self.binded and self.params_initialized |
| return self._exec_group.get_outputs(merge_multi_context=merge_multi_context, |
| begin=self._output_range[0], |
| end=self._output_range[1]) |
| |
| def get_input_grads(self, merge_multi_context=True): |
| """Get the gradients with respect to the inputs of the module. |
| |
| If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it |
| is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output |
| elements are `NDArray`. |
| |
| Parameters |
| ---------- |
| merge_multi_context : bool |
| Default is `True`. In the case when data-parallelism is used, the outputs
| will be collected from multiple devices. A `True` value indicates that we
| should merge the collected results so that they look like the output of a
| single executor.
| |
| Returns |
| ------- |
| list of NDArray or list of list of NDArray |
| Input gradients |
| """ |
| assert self.binded and self.params_initialized and self.inputs_need_grad |
| return self._exec_group.get_input_grads(merge_multi_context=merge_multi_context) |
| |
| def get_states(self, merge_multi_context=True): |
| """Get states from all devices |
| |
| If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it |
| is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output |
| elements are `NDArray`. |
| |
| Parameters |
| ---------- |
| merge_multi_context : bool |
| Default is `True`. In the case when data-parallelism is used, the states
| will be collected from multiple devices. A `True` value indicates that we
| should merge the collected results so that they look like the output of a
| single executor.
| |
| Returns |
| ------- |
| list of NDArray or list of list of NDArray |
| States |
| """ |
| assert self.binded and self.params_initialized |
| return self._exec_group.get_states(merge_multi_context=merge_multi_context) |
| |
| def set_states(self, states=None, value=None): |
| """Set value for states. Only one of states & value can be specified. |
| |
| Parameters |
| ---------- |
| states : list of list of NDArray
| Source state arrays, formatted like
| `[[state1_dev1, state1_dev2], [state2_dev1, state2_dev2]]`.
| value : number
| A single scalar value applied to all state arrays.
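|
| Examples
| --------
| A sketch resetting all states of a bound module to zero::
|
| >>> mod.set_states(value=0)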
| """ |
| assert self.binded and self.params_initialized |
| self._exec_group.set_states(states, value) |
| |
| def update_metric(self, eval_metric, labels): |
| """Evaluate and accumulate evaluation metric on outputs of the last forward computation. |
| |
| Parameters |
| ---------- |
| eval_metric : EvalMetric |
| labels : list of NDArray |
| Typically `data_batch.label`. |
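|
| Examples
| --------
| A sketch accumulating accuracy after a forward pass (`data_batch` is
| hypothetical)::
|
| >>> metric = mx.metric.Accuracy()
| >>> mod.update_metric(metric, data_batch.label)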
| """ |
| self._exec_group.update_metric(eval_metric, labels) |
| |
| def _sync_params_from_devices(self): |
| """Synchronize parameters from devices to CPU. This function should be called after |
| calling `update` that updates the parameters on the devices, before one can read the |
| latest parameters from `self._arg_params` and `self._aux_params`. |
| """ |
| self._exec_group.get_params(self._arg_params, self._aux_params) |
| self._params_dirty = False |
| |
| def save_optimizer_states(self, fname): |
| """Save optimizer (updater) state to file |
| |
| Parameters |
| ---------- |
| fname : str |
| Path to output states file. |
| """ |
| assert self.optimizer_initialized |
| |
| if self._update_on_kvstore: |
| self._kvstore.save_optimizer_states(fname) |
| else: |
| with open(fname, 'wb') as fout: |
| fout.write(self._updater.get_states()) |
| |
| def load_optimizer_states(self, fname): |
| """Load optimizer (updater) state from file |
| |
| Parameters |
| ---------- |
| fname : str |
| Path to input states file. |
| """ |
| assert self.optimizer_initialized |
| |
| if self._update_on_kvstore: |
| self._kvstore.load_optimizer_states(fname) |
| else: |
| with open(fname, 'rb') as fin:
| self._updater.set_states(fin.read())
| |
| def install_monitor(self, mon): |
| """ Install monitor on all executors """ |
| assert self.binded |
| self._exec_group.install_monitor(mon) |