| # coding: utf-8 |
| # pylint: disable=no-member, invalid-name, protected-access, no-self-use |
| # pylint: disable=too-many-branches, too-many-arguments, no-self-use |
| # pylint: disable=too-many-lines, arguments-differ |
| """Definition of various recurrent neural network cells.""" |
| from __future__ import print_function |
| |
| import warnings |
| |
| from ... import symbol, init, ndarray |
| from ...base import string_types, numeric_types |
| from ..nn import Layer |
| from .. import tensor_types |
| |
| |
| def _cells_state_shape(cells): |
| return sum([c.state_shape for c in cells], []) |
| |
| def _cells_state_info(cells, batch_size): |
| return sum([c.state_info(batch_size) for c in cells], []) |
| |
| def _cells_begin_state(cells, **kwargs): |
| return sum([c.begin_state(**kwargs) for c in cells], []) |
| |
| def _cells_unpack_weights(cells, args): |
| for cell in cells: |
| args = cell.unpack_weights(args) |
| return args |
| |
| def _cells_pack_weights(cells, args): |
| for cell in cells: |
| args = cell.pack_weights(args) |
| return args |
| |
| def _get_begin_state(cell, F, begin_state, inputs, batch_size): |
| if begin_state is None: |
| if F is ndarray: |
| ctx = inputs.context if isinstance(inputs, tensor_types) else inputs[0].context |
| with ctx: |
| begin_state = cell.begin_state(func=F.zeros, batch_size=batch_size) |
| else: |
| begin_state = cell.begin_state(func=F.zeros, batch_size=batch_size) |
| return begin_state |
| |
| def _format_sequence(length, inputs, layout, merge, in_layout=None): |
| assert inputs is not None, \ |
| "unroll(inputs=None) has been deprecated. " \ |
| "Please create input variables outside unroll." |
| |
| axis = layout.find('T') |
| batch_axis = layout.find('N') |
| batch_size = 0 |
| in_axis = in_layout.find('T') if in_layout is not None else axis |
| if isinstance(inputs, symbol.Symbol): |
| F = symbol |
| if merge is False: |
| assert len(inputs.list_outputs()) == 1, \ |
| "unroll doesn't allow grouped symbol as input. Please convert " \ |
| "to list with list(inputs) first or let unroll handle splitting." |
| inputs = list(symbol.split(inputs, axis=in_axis, num_outputs=length, |
| squeeze_axis=1)) |
| elif isinstance(inputs, ndarray.NDArray): |
| F = ndarray |
| batch_size = inputs.shape[batch_axis] |
| if merge is False: |
| assert length is None or length == inputs.shape[in_axis] |
| inputs = ndarray.split(inputs, axis=in_axis, num_outputs=inputs.shape[in_axis], |
| squeeze_axis=1) |
| else: |
| assert length is None or len(inputs) == length |
| if isinstance(inputs[0], symbol.Symbol): |
| F = symbol |
| else: |
| F = ndarray |
| batch_size = inputs[0].shape[batch_axis] |
| if merge is True: |
| inputs = [F.expand_dims(i, axis=axis) for i in inputs] |
| inputs = F.concat(*inputs, dim=axis) |
| in_axis = axis |
| |
| if isinstance(inputs, tensor_types) and axis != in_axis: |
| inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis) |
| |
| return inputs, axis, F, batch_size |
| |
| |
| class RecurrentCell(Layer): |
| """Abstract base class for RNN cells |
| |
| Parameters |
| ---------- |
| prefix : str, optional |
| Prefix for names of layers |
| (this prefix is also used for names of weights if `params` is None |
| i.e. if `params` are being created and not reused) |
| params : Parameter or None, optional |
| Container for weight sharing between cells. |
| A new Parameter container is created if `params` is None. |
| """ |
| def __init__(self, prefix=None, params=None): |
| super(RecurrentCell, self).__init__(prefix=prefix, params=params) |
| self._modified = False |
| self.reset() |
| |
| def reset(self): |
| """Reset before re-using the cell for another graph.""" |
| self._init_counter = -1 |
| self._counter = -1 |
| |
| def state_info(self, batch_size=0): |
| """shape and layout information of states""" |
| raise NotImplementedError() |
| |
| @property |
| def state_shape(self): |
| """shape(s) of states""" |
| return [ele['shape'] for ele in self.state_info()] |
| |
| @property |
| def _gate_names(self): |
| """name(s) of gates""" |
| return () |
| |
| @property |
| def _curr_prefix(self): |
| return '%st%d_'%(self.prefix, self._counter) |
| |
| def begin_state(self, func=symbol.zeros, batch_size=0, **kwargs): |
| """Initial state for this cell. |
| |
| Parameters |
| ---------- |
| func : callable, default symbol.zeros |
| Function for creating initial state. |
| |
| For Symbol API, func can be symbol.zeros, symbol.uniform, |
| symbol.var etc. Use symbol.var if you want to directly |
| feed input as states. |
| |
| For NDArray API, func can be ndarray.zeros, ndarray.ones, etc. |
| batch_size: int, default 0 |
| Only required for NDArray API. Size of the batch ('N' in layout) |
| dimension of input. |
| |
| **kwargs : |
| additional keyword arguments passed to func. For example |
| mean, std, dtype, etc. |
| |
| Returns |
| ------- |
| states : nested list of Symbol |
| Starting states for the first RNN step. |
| """ |
| assert not self._modified, \ |
| "After applying modifier cells (e.g. ZoneoutCell) the base " \ |
| "cell cannot be called directly. Call the modifier cell instead." |
| states = [] |
| for info in self.state_info(batch_size): |
| self._init_counter += 1 |
| if info is not None: |
| info.update(kwargs) |
| else: |
| info = kwargs |
| state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter), |
| **info) |
| states.append(state) |
| return states |
| |
| def unpack_weights(self, args): |
| """Unpack fused weight matrices into separate |
| weight matrices. |
| |
| For example, say you use a module object `mod` to run a network that has an lstm cell. |
| In `mod.get_params()[0]`, the lstm parameters are all represented as a single big vector. |
| `cell.unpack_weights(mod.get_params()[0])` will unpack this vector into a dictionary of |
| more readable lstm parameters - c, f, i, o gates for i2h (input to hidden) and |
| h2h (hidden to hidden) weights. |
| |
| Parameters |
| ---------- |
| args : dict of str -> NDArray |
| Dictionary containing packed weights. |
| usually from `Module.get_params()[0]`. |
| |
| Returns |
| ------- |
| args : dict of str -> NDArray |
| Dictionary with unpacked weights associated with |
| this cell. |
| |
| See Also |
| -------- |
| pack_weights: Performs the reverse operation of this function. |
| """ |
| args = args.copy() |
| if not self._gate_names: |
| return args |
| h = self._num_hidden |
| for group_name in ['i2h', 'h2h']: |
| weight = args.pop('%s%s_weight'%(self._prefix, group_name)) |
| bias = args.pop('%s%s_bias' % (self._prefix, group_name)) |
| for j, gate in enumerate(self._gate_names): |
| wname = '%s%s%s_weight' % (self._prefix, group_name, gate) |
| args[wname] = weight[j*h:(j+1)*h].copy() |
| bname = '%s%s%s_bias' % (self._prefix, group_name, gate) |
| args[bname] = bias[j*h:(j+1)*h].copy() |
| return args |
| |
| def pack_weights(self, args): |
| """Pack separate weight matrices into a single packed |
| weight. |
| |
| Parameters |
| ---------- |
| args : dict of str -> NDArray |
| Dictionary containing unpacked weights. |
| |
| Returns |
| ------- |
| args : dict of str -> NDArray |
| Dictionary with packed weights associated with |
| this cell. |
| """ |
| args = args.copy() |
| if not self._gate_names: |
| return args |
| for group_name in ['i2h', 'h2h']: |
| weight = [] |
| bias = [] |
| for gate in self._gate_names: |
| wname = '%s%s%s_weight'%(self._prefix, group_name, gate) |
| weight.append(args.pop(wname)) |
| bname = '%s%s%s_bias'%(self._prefix, group_name, gate) |
| bias.append(args.pop(bname)) |
| args['%s%s_weight'%(self._prefix, group_name)] = ndarray.concatenate(weight) |
| args['%s%s_bias'%(self._prefix, group_name)] = ndarray.concatenate(bias) |
| return args |
| |
| def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None): |
| """Unroll an RNN cell across time steps. |
| |
| Parameters |
| ---------- |
| length : int |
| number of steps to unroll |
| inputs : Symbol, list of Symbol, or None |
| If `inputs` is a single Symbol (usually the output |
| of Embedding symbol), it should have shape |
| (batch_size, length, ...) if layout == 'NTC', |
| or (length, batch_size, ...) if layout == 'TNC'. |
| |
| If `inputs` is a list of symbols (usually output of |
| previous unroll), they should all have shape |
| (batch_size, ...). |
| begin_state : nested list of Symbol, optional |
| Input states created by `begin_state()` |
| or output state of another cell. |
| Created from `begin_state()` if None. |
| layout : str, optional |
| `layout` of input symbol. Only used if inputs |
| is a single Symbol. |
| merge_outputs : bool, optional |
| If False, return outputs as a list of Symbols. |
| If True, concatenate output across time steps |
| and return a single symbol with shape |
| (batch_size, length, ...) if layout == 'NTC', |
| or (length, batch_size, ...) if layout == 'TNC'. |
| If None, output whatever is faster |
| |
| Returns |
| ------- |
| outputs : list of Symbol or Symbol |
| Symbol (if `merge_outputs` is True) or list of Symbols |
| (if `merge_outputs` is False) corresponding to the output from |
| the RNN from this unrolling. |
| |
| states : list of Symbol |
| The new state of this RNN after this unrolling. |
| The type of this symbol is same as the output of begin_state(). |
| """ |
| self.reset() |
| |
| inputs, _, F, batch_size = _format_sequence(length, inputs, layout, False) |
| begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size) |
| |
| states = begin_state |
| outputs = [] |
| for i in range(length): |
| output, states = self(inputs[i], states) |
| outputs.append(output) |
| |
| outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs) |
| |
| return outputs, states |
| |
| #pylint: disable=no-self-use |
| def _get_activation(self, F, inputs, activation, **kwargs): |
| """Get activation function. Convert if is string""" |
| if isinstance(activation, string_types): |
| return F.Activation(inputs, act_type=activation, **kwargs) |
| else: |
| return activation(inputs, **kwargs) |
| |
| def call(self, inputs, states): |
| """Unroll the recurrent cell for one time step. |
| |
| Parameters |
| ---------- |
| inputs : sym.Variable |
| input symbol, 2D, batch_size * num_units |
| states : list of sym.Variable |
| RNN state from previous step or the output of begin_state(). |
| |
| Returns |
| ------- |
| output : Symbol |
| Symbol corresponding to the output from the RNN when unrolling |
| for a single time step. |
| states : list of Symbol |
| The new state of this RNN after this unrolling. |
| The type of this symbol is same as the output of begin_state(). |
| This can be used as input state to the next time step |
| of this RNN. |
| |
| See Also |
| -------- |
| begin_state: This function can provide the states for the first time step. |
| unroll: This function unrolls an RNN for a given number of (>=1) time steps. |
| """ |
| # pylint: disable= arguments-differ |
| self._counter += 1 |
| return super(RecurrentCell, self).call(inputs, states) |
| |
| |
| |
| class RNNCell(RecurrentCell): |
| """Simple recurrent neural network cell. |
| |
| Parameters |
| ---------- |
| num_hidden : int |
| number of units in output symbol |
| activation : str or Symbol, default 'tanh' |
| type of activation function |
| prefix : str, default 'rnn_' |
| prefix for name of layers |
| (and name of weight if params is None) |
| params : Parameter or None |
| container for weight sharing between cells. |
| created if None. |
| """ |
| def __init__(self, num_hidden, activation='tanh', num_input=0, |
| prefix=None, params=None): |
| super(RNNCell, self).__init__(prefix=prefix, params=params) |
| self._num_hidden = num_hidden |
| self._activation = activation |
| self._num_input = num_input |
| self.i2h_weight = self.params.get('i2h_weight', shape=(num_hidden, num_input)) |
| self.i2h_bias = self.params.get('i2h_bias', shape=(num_hidden,)) |
| self.h2h_weight = self.params.get('h2h_weight', shape=(num_hidden, num_hidden)) |
| self.h2h_bias = self.params.get('h2h_bias', shape=(num_hidden,)) |
| |
| def state_info(self, batch_size=0): |
| return [{'shape': (batch_size, self._num_hidden), '__layout__': 'NC'}] |
| |
| @property |
| def _gate_names(self): |
| return ('',) |
| |
| def _alias(self): |
| return 'rnn' |
| |
| def forward(self, F, inputs, states, i2h_weight, i2h_bias, |
| h2h_weight, h2h_bias): |
| name = self._curr_prefix |
| i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias, |
| num_hidden=self._num_hidden, |
| name='%si2h'%name) |
| h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias, |
| num_hidden=self._num_hidden, |
| name='%sh2h'%name) |
| output = self._get_activation(F, i2h + h2h, self._activation, |
| name='%sout'%name) |
| |
| return output, [output] |
| |
| |
| class LSTMCell(RecurrentCell): |
| """Long-Short Term Memory (LSTM) network cell. |
| |
| Parameters |
| ---------- |
| num_hidden : int |
| number of units in output symbol |
| prefix : str, default 'lstm_' |
| prefix for name of layers |
| (and name of weight if params is None) |
| params : Parameter or None |
| container for weight sharing between cells. |
| created if None. |
| forget_bias : bias added to forget gate, default 1.0. |
| Jozefowicz et al. 2015 recommends setting this to 1.0 |
| """ |
| def __init__(self, num_hidden, forget_bias=1.0, num_input=0, |
| prefix=None, params=None): |
| super(LSTMCell, self).__init__(prefix=prefix, params=params) |
| |
| self._num_hidden = num_hidden |
| self._num_input = num_input |
| self.i2h_weight = self.params.get('i2h_weight', shape=(4*num_hidden, num_input)) |
| self.h2h_weight = self.params.get('h2h_weight', shape=(4*num_hidden, num_hidden)) |
| # we add the forget_bias to i2h_bias, this adds the bias to the forget gate activation |
| self.i2h_bias = self.params.get('i2h_bias', shape=(4*num_hidden,), |
| init=init.LSTMBias(forget_bias=forget_bias)) |
| self.h2h_bias = self.params.get('h2h_bias', shape=(4*num_hidden,)) |
| |
| def state_info(self, batch_size=0): |
| return [{'shape': (batch_size, self._num_hidden), '__layout__': 'NC'}, |
| {'shape': (batch_size, self._num_hidden), '__layout__': 'NC'}] |
| |
| @property |
| def _gate_names(self): |
| return ['_i', '_f', '_c', '_o'] |
| |
| def _alias(self): |
| return 'lstm' |
| |
| def forward(self, F, inputs, states, i2h_weight, i2h_bias, |
| h2h_weight, h2h_bias): |
| name = self._curr_prefix |
| i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias, |
| num_hidden=self._num_hidden*4, |
| name='%si2h'%name) |
| h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias, |
| num_hidden=self._num_hidden*4, |
| name='%sh2h'%name) |
| gates = i2h + h2h |
| slice_gates = F.SliceChannel(gates, num_outputs=4, |
| name="%sslice"%name) |
| in_gate = F.Activation(slice_gates[0], act_type="sigmoid", |
| name='%si'%name) |
| forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", |
| name='%sf'%name) |
| in_transform = F.Activation(slice_gates[2], act_type="tanh", |
| name='%sc'%name) |
| out_gate = F.Activation(slice_gates[3], act_type="sigmoid", |
| name='%so'%name) |
| next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform, |
| name='%sstate'%name) |
| next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"), |
| name='%sout'%name) |
| |
| return next_h, [next_h, next_c] |
| |
| |
| class GRUCell(RecurrentCell): |
| """Gated Rectified Unit (GRU) network cell. |
| Note: this is an implementation of the cuDNN version of GRUs |
| (slight modification compared to Cho et al. 2014). |
| |
| Parameters |
| ---------- |
| num_hidden : int |
| number of units in output symbol |
| prefix : str, default 'gru_' |
| prefix for name of layers |
| (and name of weight if params is None) |
| params : Parameter or None |
| container for weight sharing between cells. |
| created if None. |
| """ |
| def __init__(self, num_hidden, num_input=0, prefix=None, params=None): |
| super(GRUCell, self).__init__(prefix=prefix, params=params) |
| self._num_hidden = num_hidden |
| self.i2h_weight = self.params.get('i2h_weight', shape=(3*num_hidden, num_input)) |
| self.h2h_weight = self.params.get('h2h_weight', shape=(3*num_hidden, num_hidden)) |
| self.i2h_bias = self.params.get('i2h_bias', shape=(3*num_hidden)) |
| self.h2h_bias = self.params.get('h2h_bias', shape=(3*num_hidden)) |
| |
| def state_info(self, batch_size=0): |
| return [{'shape': (batch_size, self._num_hidden), '__layout__': 'NC'}] |
| |
| @property |
| def _gate_names(self): |
| return ['_r', '_z', '_o'] |
| |
| def _alias(self): |
| return 'gru' |
| |
| def forward(self, F, inputs, states, i2h_weight, i2h_bias, |
| h2h_weight, h2h_bias): |
| # pylint: disable=too-many-locals |
| name = self._curr_prefix |
| prev_state_h = states[0] |
| i2h = F.FullyConnected(data=inputs, |
| weight=i2h_weight, |
| bias=i2h_bias, |
| num_hidden=self._num_hidden * 3, |
| name="%si2h" % name) |
| h2h = F.FullyConnected(data=prev_state_h, |
| weight=h2h_weight, |
| bias=h2h_bias, |
| num_hidden=self._num_hidden * 3, |
| name="%sh2h" % name) |
| |
| i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3, name="%si2h_slice" % name) |
| h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3, name="%sh2h_slice" % name) |
| |
| reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid", |
| name="%sr_act" % name) |
| update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid", |
| name="%sz_act" % name) |
| |
| next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh", |
| name="%sh_act" % name) |
| |
| next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h, |
| name='%sout' % name) |
| |
| return next_h, [next_h] |
| |
| |
| class FusedRNNCell(RecurrentCell): |
| """Fusing RNN layers across time step into one kernel. |
| Improves speed but is less flexible. Currently only |
| supported if using cuDNN on GPU. |
| |
| Parameters |
| ---------- |
| """ |
| def __init__(self, num_hidden, num_layers=1, mode='lstm', bidirectional=False, |
| dropout=0., get_next_state=False, forget_bias=1.0, num_input=0, |
| prefix=None, params=None): |
| self._num_hidden = num_hidden |
| self._num_layers = num_layers |
| self._mode = mode |
| self._bidirectional = bidirectional |
| self._dropout = dropout |
| self._get_next_state = get_next_state |
| self._directions = ['l', 'r'] if bidirectional else ['l'] |
| super(FusedRNNCell, self).__init__(prefix=prefix, params=params) |
| |
| initializer = init.FusedRNN(None, num_hidden, num_layers, mode, |
| bidirectional, forget_bias) |
| self.parameters = self.params.get('parameters', init=initializer, |
| shape=(self._num_input_to_size(num_input),)) |
| |
| def state_info(self, batch_size=0): |
| b = self._bidirectional + 1 |
| n = (self._mode == 'lstm') + 1 |
| return [{'shape': (b*self._num_layers, batch_size, self._num_hidden), |
| '__layout__': 'LNC'} for _ in range(n)] |
| |
| @property |
| def _gate_names(self): |
| return {'rnn_relu': [''], |
| 'rnn_tanh': [''], |
| 'lstm': ['_i', '_f', '_c', '_o'], |
| 'gru': ['_r', '_z', '_o']}[self._mode] |
| |
| @property |
| def _num_gates(self): |
| return len(self._gate_names) |
| |
| def _alias(self): |
| return self._mode |
| |
| def _size_to_num_input(self, size): |
| b = len(self._directions) |
| m = self._num_gates |
| h = self._num_hidden |
| return size//b//h//m - (self._num_layers - 1)*(h+b*h+2) - h - 2 |
| |
| def _num_input_to_size(self, num_input): |
| if num_input == 0: |
| return 0 |
| b = self._bidirectional + 1 |
| m = self._num_gates |
| h = self._num_hidden |
| return (num_input+h+2)*h*m*b + (self._num_layers-1)*m*h*(h+b*h+2)*b |
| |
| def _slice_weights(self, arr, li, lh): |
| """slice fused rnn weights""" |
| args = {} |
| gate_names = self._gate_names |
| directions = self._directions |
| |
| b = len(directions) |
| p = 0 |
| for layer in range(self._num_layers): |
| for direction in directions: |
| for gate in gate_names: |
| name = '%s%s%d_i2h%s_weight'%(self._prefix, direction, layer, gate) |
| if layer > 0: |
| size = b*lh*lh |
| args[name] = arr[p:p+size].reshape((lh, b*lh)) |
| else: |
| size = li*lh |
| args[name] = arr[p:p+size].reshape((lh, li)) |
| p += size |
| for gate in gate_names: |
| name = '%s%s%d_h2h%s_weight'%(self._prefix, direction, layer, gate) |
| size = lh**2 |
| args[name] = arr[p:p+size].reshape((lh, lh)) |
| p += size |
| |
| for layer in range(self._num_layers): |
| for direction in directions: |
| for gate in gate_names: |
| name = '%s%s%d_i2h%s_bias'%(self._prefix, direction, layer, gate) |
| args[name] = arr[p:p+lh] |
| p += lh |
| for gate in gate_names: |
| name = '%s%s%d_h2h%s_bias'%(self._prefix, direction, layer, gate) |
| args[name] = arr[p:p+lh] |
| p += lh |
| |
| assert p == arr.size, "Invalid parameters size for FusedRNNCell" |
| return args |
| |
| def unpack_weights(self, args): |
| args = args.copy() |
| arr = args.pop(self.parameters.name) |
| num_input = self._size_to_num_input(arr.size) |
| nargs = self._slice_weights(arr, num_input, self._num_hidden) |
| args.update({name: nd.copy() for name, nd in nargs.items()}) |
| return args |
| |
| def pack_weights(self, args): |
| args = args.copy() |
| w0 = args['%sl0_i2h%s_weight'%(self._prefix, self._gate_names[0])] |
| num_input = w0.shape[1] |
| total = self._num_input_to_size(num_input) |
| |
| arr = ndarray.zeros((total,), ctx=w0.context, dtype=w0.dtype) |
| for name, nd in self._slice_weights(arr, num_input, self._num_hidden).items(): |
| nd[:] = args.pop(name) |
| args[self.parameters.name] = arr |
| return args |
| |
| def __call__(self, inputs, states): |
| raise NotImplementedError("FusedRNNCell cannot be stepped. Please use unroll") |
| |
| def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None): |
| self.reset() |
| |
| inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, True) |
| if axis == 1: |
| warnings.warn("NTC layout detected. Consider using " |
| "TNC for FusedRNNCell for faster speed") |
| inputs = F.swapaxes(inputs, dim1=0, dim2=1) |
| else: |
| assert axis == 0, "Unsupported layout %s"%layout |
| begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size) |
| |
| states = begin_state |
| if self._mode == 'lstm': |
| states = {'state': states[0], 'state_cell': states[1]} # pylint: disable=redefined-variable-type |
| else: |
| states = {'state': states[0]} |
| |
| if isinstance(inputs, symbol.Symbol): |
| parameters = self.parameters.var() |
| else: |
| parameters = self.parameters.data(inputs.context) |
| |
| rnn = F.RNN(data=inputs, parameters=parameters, |
| state_size=self._num_hidden, num_layers=self._num_layers, |
| bidirectional=self._bidirectional, p=self._dropout, |
| state_outputs=self._get_next_state, |
| mode=self._mode, name=self._prefix+'rnn', |
| **states) |
| |
| if not self._get_next_state: |
| outputs, states = rnn, [] |
| elif self._mode == 'lstm': |
| outputs, states = rnn[0], [rnn[1], rnn[2]] |
| else: |
| outputs, states = rnn[0], [rnn[1]] |
| |
| if axis == 1: |
| outputs = F.swapaxes(outputs, dim1=0, dim2=1) |
| |
| outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs) |
| |
| return outputs, states |
| |
| def unfuse(self): |
| """Unfuse the fused RNN in to a stack of rnn cells. |
| |
| Returns |
| ------- |
| cell : SequentialRNNCell |
| unfused cell that can be used for stepping, and can run on CPU. |
| """ |
| stack = SequentialRNNCell() |
| get_cell = {'rnn_relu': lambda cell_prefix: RNNCell(self._num_hidden, |
| activation='relu', |
| prefix=cell_prefix), |
| 'rnn_tanh': lambda cell_prefix: RNNCell(self._num_hidden, |
| activation='tanh', |
| prefix=cell_prefix), |
| 'lstm': lambda cell_prefix: LSTMCell(self._num_hidden, |
| prefix=cell_prefix), |
| 'gru': lambda cell_prefix: GRUCell(self._num_hidden, |
| prefix=cell_prefix)}[self._mode] |
| for i in range(self._num_layers): |
| if self._bidirectional: |
| stack.add(BidirectionalCell( |
| get_cell('%sl%d_'%(self._prefix, i)), |
| get_cell('%sr%d_'%(self._prefix, i)), |
| output_prefix='%sbi_l%d_'%(self._prefix, i))) |
| else: |
| stack.add(get_cell('%sl%d_'%(self._prefix, i))) |
| |
| if self._dropout > 0 and i != self._num_layers - 1: |
| stack.add(DropoutCell(self._dropout, prefix='%s_dropout%d_'%(self._prefix, i))) |
| |
| return stack |
| |
| |
| class SequentialRNNCell(RecurrentCell): |
| """Sequantially stacking multiple RNN cells.""" |
| def __init__(self): |
| super(SequentialRNNCell, self).__init__(prefix='', params=None) |
| |
| def add(self, cell): |
| """Append a cell into the stack. |
| |
| Parameters |
| ---------- |
| cell : rnn cell |
| """ |
| self.register_child(cell) |
| |
| def state_info(self, batch_size=0): |
| return _cells_state_info(self._children, batch_size) |
| |
| def begin_state(self, **kwargs): |
| assert not self._modified, \ |
| "After applying modifier cells (e.g. ZoneoutCell) the base " \ |
| "cell cannot be called directly. Call the modifier cell instead." |
| return _cells_begin_state(self._children, **kwargs) |
| |
| def unpack_weights(self, args): |
| return _cells_unpack_weights(self._children, args) |
| |
| def pack_weights(self, args): |
| return _cells_pack_weights(self._children, args) |
| |
| def __call__(self, inputs, states): |
| self._counter += 1 |
| next_states = [] |
| p = 0 |
| for cell in self._children: |
| assert not isinstance(cell, BidirectionalCell) |
| n = len(cell.state_info()) |
| state = states[p:p+n] |
| p += n |
| inputs, state = cell(inputs, state) |
| next_states.append(state) |
| return inputs, sum(next_states, []) |
| |
| def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None): |
| self.reset() |
| |
| inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None) |
| num_cells = len(self._children) |
| begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size) |
| |
| p = 0 |
| next_states = [] |
| for i, cell in enumerate(self._children): |
| n = len(cell.state_info()) |
| states = begin_state[p:p+n] |
| p += n |
| inputs, states = cell.unroll(length, inputs=inputs, begin_state=states, layout=layout, |
| merge_outputs=None if i < num_cells-1 else merge_outputs) |
| next_states.extend(states) |
| |
| return inputs, next_states |
| |
| def forward(self, *args, **kwargs): |
| raise NotImplementedError |
| |
| |
| class DropoutCell(RecurrentCell): |
| """Apply dropout on input. |
| |
| Parameters |
| ---------- |
| dropout : float |
| percentage of elements to drop out, which |
| is 1 - percentage to retain. |
| """ |
| def __init__(self, dropout, prefix=None, params=None): |
| super(DropoutCell, self).__init__(prefix, params) |
| assert isinstance(dropout, numeric_types), "dropout probability must be a number" |
| self.dropout = dropout |
| |
| def state_info(self, batch_size=0): |
| return [] |
| |
| def _alias(self): |
| return 'dropout' |
| |
| def forward(self, F, inputs, states): |
| if self.dropout > 0: |
| inputs = F.Dropout(data=inputs, p=self.dropout) |
| return inputs, states |
| |
| def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None): |
| self.reset() |
| |
| inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs) |
| if isinstance(inputs, tensor_types): |
| return self.forward(F, inputs, begin_state if begin_state else []) |
| else: |
| return super(DropoutCell, self).unroll( |
| length, inputs, begin_state=begin_state, layout=layout, |
| merge_outputs=merge_outputs) |
| |
| |
| class ModifierCell(RecurrentCell): |
| """Base class for modifier cells. A modifier |
| cell takes a base cell, apply modifications |
| on it (e.g. Zoneout), and returns a new cell. |
| |
| After applying modifiers the base cell should |
| no longer be called directly. The modifer cell |
| should be used instead. |
| """ |
| def __init__(self, base_cell): |
| super(ModifierCell, self).__init__(prefix=None, params=None) |
| base_cell._modified = True |
| self.base_cell = base_cell |
| |
| @property |
| def params(self): |
| self._own_params = False |
| return self.base_cell.params |
| |
| def state_info(self, batch_size=0): |
| return self.base_cell.state_info(batch_size) |
| |
| def begin_state(self, func=symbol.zeros, **kwargs): |
| assert not self._modified, \ |
| "After applying modifier cells (e.g. DropoutCell) the base " \ |
| "cell cannot be called directly. Call the modifier cell instead." |
| self.base_cell._modified = False |
| begin = self.base_cell.begin_state(func=func, **kwargs) |
| self.base_cell._modified = True |
| return begin |
| |
| def unpack_weights(self, args): |
| return self.base_cell.unpack_weights(args) |
| |
| def pack_weights(self, args): |
| return self.base_cell.pack_weights(args) |
| |
| def forward(self, F, inputs, states): |
| raise NotImplementedError |
| |
| |
| class ZoneoutCell(ModifierCell): |
| """Apply Zoneout on base cell.""" |
| def __init__(self, base_cell, zoneout_outputs=0., zoneout_states=0.): |
| assert not isinstance(base_cell, FusedRNNCell), \ |
| "FusedRNNCell doesn't support zoneout. " \ |
| "Please unfuse first." |
| assert not isinstance(base_cell, BidirectionalCell), \ |
| "BidirectionalCell doesn't support zoneout since it doesn't support step. " \ |
| "Please add ZoneoutCell to the cells underneath instead." |
| assert not isinstance(base_cell, SequentialRNNCell) or not base_cell._bidirectional, \ |
| "Bidirectional SequentialRNNCell doesn't support zoneout. " \ |
| "Please add ZoneoutCell to the cells underneath instead." |
| super(ZoneoutCell, self).__init__(base_cell) |
| self.zoneout_outputs = zoneout_outputs |
| self.zoneout_states = zoneout_states |
| self.prev_output = None |
| |
| def _alias(self): |
| return 'zoneout' |
| |
| def reset(self): |
| super(ZoneoutCell, self).reset() |
| self.prev_output = None |
| |
| def forward(self, F, inputs, states): |
| cell, p_outputs, p_states = self.base_cell, self.zoneout_outputs, self.zoneout_states |
| next_output, next_states = cell(inputs, states) |
| mask = (lambda p, like: F.Dropout(F.ones_like(like), p=p)) |
| |
| prev_output = self.prev_output |
| if prev_output is None: |
| prev_output = F.zeros_like(next_output) |
| |
| output = (F.where(mask(p_outputs, next_output), next_output, prev_output) |
| if p_outputs != 0. else next_output) |
| states = ([F.where(mask(p_states, new_s), new_s, old_s) for new_s, old_s in |
| zip(next_states, states)] if p_states != 0. else next_states) |
| |
| self.prev_output = output |
| |
| return output, states |
| |
| |
| class ResidualCell(ModifierCell): |
| """ |
| Adds residual connection as described in Wu et al, 2016 |
| (https://arxiv.org/abs/1609.08144). |
| Output of the cell is output of the base cell plus input. |
| """ |
| |
| def __init__(self, base_cell): |
| super(ResidualCell, self).__init__(base_cell) |
| |
| def forward(self, F, inputs, states): |
| output, states = self.base_cell(inputs, states) |
| output = F.elemwise_add(output, inputs, name="%s_plus_residual" % output.name) |
| return output, states |
| |
| def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None): |
| self.reset() |
| |
| self.base_cell._modified = False |
| outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state, |
| layout=layout, merge_outputs=merge_outputs) |
| self.base_cell._modified = True |
| |
| merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \ |
| merge_outputs |
| inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs) |
| if merge_outputs: |
| outputs = F.elemwise_add(outputs, inputs) |
| else: |
| outputs = [F.elemwise_add(i, j) for i, j in zip(outputs, inputs)] |
| |
| return outputs, states |
| |
| |
| class BidirectionalCell(RecurrentCell): |
| """Bidirectional RNN cell. |
| |
| Parameters |
| ---------- |
| l_cell : RecurrentCell |
| cell for forward unrolling |
| r_cell : RecurrentCell |
| cell for backward unrolling |
| """ |
| def __init__(self, l_cell, r_cell, output_prefix='bi_'): |
| super(BidirectionalCell, self).__init__(prefix='', params=None) |
| self.register_child(l_cell) |
| self.register_child(r_cell) |
| self._output_prefix = output_prefix |
| |
| def unpack_weights(self, args): |
| return _cells_unpack_weights(self._children, args) |
| |
| def pack_weights(self, args): |
| return _cells_pack_weights(self._children, args) |
| |
| def __call__(self, inputs, states): |
| raise NotImplementedError("Bidirectional cannot be stepped. Please use unroll") |
| |
| def state_info(self, batch_size=0): |
| return _cells_state_info(self._children, batch_size) |
| |
| def begin_state(self, **kwargs): |
| assert not self._modified, \ |
| "After applying modifier cells (e.g. DropoutCell) the base " \ |
| "cell cannot be called directly. Call the modifier cell instead." |
| return _cells_begin_state(self._children, **kwargs) |
| |
| def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None): |
| self.reset() |
| |
| inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False) |
| begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size) |
| |
| states = begin_state |
| l_cell, r_cell = self._children |
| l_outputs, l_states = l_cell.unroll(length, inputs=inputs, |
| begin_state=states[:len(l_cell.state_info(batch_size))], |
| layout=layout, merge_outputs=merge_outputs) |
| r_outputs, r_states = r_cell.unroll(length, |
| inputs=list(reversed(inputs)), |
| begin_state=states[len(l_cell.state_info(batch_size)):], |
| layout=layout, merge_outputs=merge_outputs) |
| |
| if merge_outputs is None: |
| merge_outputs = (isinstance(l_outputs, tensor_types) |
| and isinstance(r_outputs, tensor_types)) |
| l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs) |
| r_outputs, _, _, _ = _format_sequence(None, r_outputs, layout, merge_outputs) |
| |
| if merge_outputs: |
| r_outputs = F.reverse(r_outputs, axis=axis) |
| outputs = F.concat(l_outputs, r_outputs, dim=2, name='%sout'%self._output_prefix) |
| else: |
| outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i)) |
| for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed(r_outputs)))] |
| |
| states = [l_states, r_states] |
| return outputs, states |