| # coding: utf-8 |
| # pylint: disable=too-many-arguments, no-member |
| """Functions for constructing recurrent neural networks.""" |
| import warnings |
| |
| from ..model import save_checkpoint, load_checkpoint |
| from .rnn_cell import BaseRNNCell |
| |
| def rnn_unroll(cell, length, inputs=None, begin_state=None, input_prefix='', layout='NTC'): |
| """Deprecated. Please use cell.unroll instead""" |
| warnings.warn('rnn_unroll is deprecated. Please call cell.unroll directly.') |
| return cell.unroll(length=length, inputs=inputs, begin_state=begin_state, |
| input_prefix=input_prefix, layout=layout) |
| |
| def save_rnn_checkpoint(cells, prefix, epoch, symbol, arg_params, aux_params): |
| """Save checkpoint for model using RNN cells. |
| Unpacks weight before saving. |
| |
| Parameters |
| ---------- |
| cells : RNNCells or list of RNNCells |
| The RNN cells used by this symbol. |
| prefix : str |
| Prefix of model name. |
| epoch : int |
| The epoch number of the model. |
| symbol : Symbol |
| The input symbol |
| arg_params : dict of str to NDArray |
| Model parameter, dict of name to NDArray of net's weights. |
| aux_params : dict of str to NDArray |
| Model parameter, dict of name to NDArray of net's auxiliary states. |
| |
| Notes |
| ----- |
| - ``prefix-symbol.json`` will be saved for symbol. |
| - ``prefix-epoch.params`` will be saved for parameters. |
| """ |
| if isinstance(cells, BaseRNNCell): |
| cells = [cells] |
| for cell in cells: |
| arg_params = cell.unpack_weights(arg_params) |
| save_checkpoint(prefix, epoch, symbol, arg_params, aux_params) |
| |
| def load_rnn_checkpoint(cells, prefix, epoch): |
| """Load model checkpoint from file. |
| Pack weights after loading. |
| |
| Parameters |
| ---------- |
| cells : RNNCells or list of RNNCells |
| The RNN cells used by this symbol. |
| prefix : str |
| Prefix of model name. |
| epoch : int |
| Epoch number of model we would like to load. |
| |
| Returns |
| ------- |
| symbol : Symbol |
| The symbol configuration of computation network. |
| arg_params : dict of str to NDArray |
| Model parameter, dict of name to NDArray of net's weights. |
| aux_params : dict of str to NDArray |
| Model parameter, dict of name to NDArray of net's auxiliary states. |
| |
| Notes |
| ----- |
| - symbol will be loaded from ``prefix-symbol.json``. |
| - parameters will be loaded from ``prefix-epoch.params``. |
| """ |
| sym, arg, aux = load_checkpoint(prefix, epoch) |
| if isinstance(cells, BaseRNNCell): |
| cells = [cells] |
| for cell in cells: |
| arg = cell.pack_weights(arg) |
| |
| return sym, arg, aux |
| |
| def do_rnn_checkpoint(cells, prefix, period=1): |
| """Make a callback to checkpoint Module to prefix every epoch. |
| unpacks weights used by cells before saving. |
| |
| Parameters |
| ---------- |
| cells : subclass of BaseRNNCell |
| RNN cells used by this module. |
| prefix : str |
| The file prefix to checkpoint to |
| period : int |
| How many epochs to wait before checkpointing. Default is 1. |
| |
| Returns |
| ------- |
| callback : function |
| The callback function that can be passed as iter_end_callback to fit. |
| """ |
| period = int(max(1, period)) |
| # pylint: disable=unused-argument |
| def _callback(iter_no, sym=None, arg=None, aux=None): |
| """The checkpoint function.""" |
| if (iter_no + 1) % period == 0: |
| save_rnn_checkpoint(cells, prefix, iter_no+1, sym, arg, aux) |
| return _callback |