<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
Generating Random Sentences with LSTM RNN
=========================================
This tutorial shows how to train an LSTM (long short-term memory) RNN
(recurrent neural network) to perform character-level sequence training
and prediction. The original model, usually called `char-rnn`, is
described in [Andrej Karpathy's
blog](http://karpathy.github.io/2015/05/21/rnn-effectiveness/), with a
reference implementation in Torch available
[here](https://github.com/karpathy/char-rnn).
Because MXNet.jl does not yet have a specialized model for recurrent
neural networks, the example shown here implements an LSTM with the
default FeedForward model by explicitly unfolding the network over time.
We use fixed-length input sequences for training. The code is adapted
from the [char-rnn example for MXNet's Python
binding](https://github.com/dmlc/mxnet/blob/master/example/rnn/char_lstm.ipynb),
which demonstrates how to use the low-level
[Symbolic API](@ref) to build customized neural
network models directly.
The most important code snippets of this example are shown and explained
here. To see and run the complete code, please refer to the
[examples/char-lstm](https://github.com/dmlc/MXNet.jl/tree/master/examples/char-lstm)
directory. You will need to install
[Iterators.jl](https://github.com/JuliaLang/Iterators.jl) and
[StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl) to run this
example.
LSTM Cells
----------
Christopher Olah has a [great blog post about LSTM](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) with
beautiful and clear illustrations, so we will not repeat the definition
and explanation of an LSTM cell here. Basically, an LSTM cell takes
input `x`, as well as the previous states (`c` and `h`), and produces
the next states. We define a helper type to bundle the two state
variables together.
Because the LSTM weights are shared across all time steps when we
explicitly unfold the network, we also define a helper type to hold all
the weights (and biases) of an LSTM cell for convenience.
Note that all the variables are of type `SymbolicNode`. We construct the
LSTM network as a symbolic computation graph, which is then instantiated
with `NDArray` for actual computation.
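To make the two helper types concrete without depending on MXNet.jl, here is a minimal sketch in plain Julia 1.x. It is a stand-in under stated assumptions, not the example's code: the fields are ordinary arrays rather than `SymbolicNode`s, the field names are illustrative, the four gates are assumed to be stacked row-wise in each weight matrix, and `zero_state` and `random_param` are hypothetical helpers.

```julia
using Random

# Bundles the two state variables of an LSTM cell.
# (Assumption: plain vectors instead of mx.SymbolicNode.)
struct LSTMState
    c::Vector{Float64}  # memory cell
    h::Vector{Float64}  # hidden state, i.e. the cell's output
end

# Holds the weights and biases of one LSTM cell; the same value is shared by
# every unrolled time step of a layer. The four gates are assumed to be
# stacked row-wise, so each matrix has 4 * num_hidden rows.
struct LSTMParam
    i2h_W::Matrix{Float64}  # input-to-hidden weights
    h2h_W::Matrix{Float64}  # hidden-to-hidden weights
    i2h_b::Vector{Float64}  # input-to-hidden bias
    h2h_b::Vector{Float64}  # hidden-to-hidden bias
end

# Hypothetical helper: all-zero initial state for num_hidden units.
zero_state(num_hidden::Int) = LSTMState(zeros(num_hidden), zeros(num_hidden))

# Hypothetical helper: small random weights and zero biases.
function random_param(num_input::Int, num_hidden::Int)
    LSTMParam(0.1 .* randn(4 * num_hidden, num_input),
              0.1 .* randn(4 * num_hidden, num_hidden),
              zeros(4 * num_hidden),
              zeros(4 * num_hidden))
end
```

Because one `LSTMParam` value is reused at every time step of a layer, sharing weights during unrolling amounts to passing the same object to each step.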
The following figure is stolen (permission requested) from [Christopher
Olah's blog](http://colah.github.io/posts/2015-08-Understanding-LSTMs/),
which illustrates exactly what the LSTM cell code computes.
![image](images/LSTM3-chain.png)
In particular, instead of defining the four gates independently, we
compute them together and then use `SliceChannel` to split the result
into four outputs. The gate computations are all done with the symbolic
API. The return value is an LSTM state containing the outputs of the
LSTM cell.
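The fused-gate computation can be sketched in plain Julia 1.x as follows. This is an illustration, not the MXNet.jl code: arrays replace symbolic nodes, index slicing plays the role of `SliceChannel`, and the gate order (input gate, input transform, forget gate, output gate) is an assumption.

```julia
# Plain-array stand-ins for the helper types (assumption: not SymbolicNode).
struct LSTMState
    c::Vector{Float64}
    h::Vector{Float64}
end

struct LSTMParam
    i2h_W::Matrix{Float64}
    h2h_W::Matrix{Float64}
    i2h_b::Vector{Float64}
    h2h_b::Vector{Float64}
end

sigmoid(z) = 1 / (1 + exp(-z))

# One LSTM time step: x is the input vector, prev the previous state.
function lstm_cell(x::Vector{Float64}, prev::LSTMState, p::LSTMParam)
    n = length(prev.h)
    # A single fused affine map yields the pre-activations of all four gates.
    gates = p.i2h_W * x .+ p.i2h_b .+ p.h2h_W * prev.h .+ p.h2h_b
    # Slice the 4n-vector into four n-vectors (what SliceChannel does symbolically).
    in_gate      = sigmoid.(gates[1:n])
    in_transform = tanh.(gates[n+1:2n])
    forget_gate  = sigmoid.(gates[2n+1:3n])
    out_gate     = sigmoid.(gates[3n+1:4n])
    # Keep part of the old memory and add part of the new candidate.
    next_c = forget_gate .* prev.c .+ in_gate .* in_transform
    # The output is a gated, squashed view of the new memory.
    next_h = out_gate .* tanh.(next_c)
    return LSTMState(next_c, next_h)
end
```

With all-zero weights every sigmoid gate evaluates to 0.5 and the candidate to 0, so `next_c` is half of `prev.c`, which makes a handy sanity check.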
Unfolding LSTM
--------------
Using the LSTM cell defined above, we are now ready to define a function
to unfold an LSTM network with L layers and T time steps. The first part
of the function just defines all the symbolic variables for the shared
weights and states.
`embed_W` holds the weights used for character embedding, i.e., the
mapping from one-hot encoded characters to real vectors. `pred_W` and
`pred_b` are the weights and bias for the final prediction at each time
step.
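Multiplying a one-hot vector by an embedding matrix simply picks out one column, which is why the embedding can be thought of as a learned lookup table. The plain-Julia sketch below shows this, plus the per-step prediction (an affine map followed by a softmax). The `dim_embed × vocab_size` layout of `embed_W` and the helper names `onehot`, `embed` and `predict` are assumptions for illustration.

```julia
# One-hot column vector with a 1 at `index` (hypothetical helper).
function onehot(index::Int, vocab_size::Int)
    v = zeros(vocab_size)
    v[index] = 1.0
    return v
end

# Embedding: embed_W is assumed to be dim_embed × vocab_size, so the product
# with a one-hot vector returns column `index`, i.e. a table lookup.
embed(embed_W::Matrix{Float64}, index::Int) = embed_W * onehot(index, size(embed_W, 2))

# Prediction at one time step: scores = pred_W * h + pred_b, then softmax.
function predict(pred_W::Matrix{Float64}, pred_b::Vector{Float64}, h::Vector{Float64})
    scores = pred_W * h .+ pred_b
    e = exp.(scores .- maximum(scores))  # shift by the max for numerical stability
    return e ./ sum(e)
end
```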
Then we define the weights for each LSTM cell. Note that there is one
cell per layer, and it is replicated (unrolled) over time. The states,
however, are *not* shared over time. Instead, we define the initial
states here, at the beginning of a sequence, and update them with the
output states at each time step as we explicitly unroll the LSTM.
Unrolling over time is a straightforward procedure: stack the embedding
layer, then the LSTM cells, and finally the prediction layer on top.
During unrolling, we update the states and collect all the outputs. Note
that each time step takes data and label as inputs. If the LSTM is named
`:ptb`, the data and label at step `t` will be named `:ptb_data_$t` and
`:ptb_label_$t`. Later on, when we prepare the data, we will define the
data provider to match those names.
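The unrolling loop and the naming convention can be sketched generically in plain Julia 1.x. This is a hypothetical illustration rather than the example's own function: `cell(x, state, layer)` and `output_layer(h)` are placeholder callbacks for the LSTM cell and the prediction layer, and each state is a `(c, h)` tuple of ordinary values.

```julia
# Names of the per-step inputs, e.g. data_name(:ptb, 3) == :ptb_data_3.
data_name(name::Symbol, t::Int)  = Symbol("$(name)_data_$(t)")
label_name(name::Symbol, t::Int) = Symbol("$(name)_label_$(t)")

# Unroll over time. `inputs` holds one (already embedded) input per time step;
# `init_states` holds one (c, h) tuple per layer. `cell(x, state, layer)` and
# `output_layer(h)` are hypothetical callbacks.
function unroll(cell, output_layer, inputs::Vector, init_states::Vector)
    states = copy(init_states)          # states are updated, not shared, over time
    outputs = Any[]
    for x in inputs                     # time steps t = 1..T
        hidden = x                      # the input enters the bottom layer
        for layer in eachindex(states)  # layers l = 1..L
            states[layer] = cell(hidden, states[layer], layer)
            hidden = states[layer][2]   # this layer's h feeds the layer above
        end
        push!(outputs, output_layer(hidden))  # one prediction per time step
    end
    return outputs, states              # per-step outputs and final states
end
```

The same `cell` (and hence the same weights) is applied at every time step, while each layer's state slot is overwritten step by step; the returned final `states` correspond to the dangling outputs that the symbolic version connects to BlockGrad.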
Note that at each time step, the prediction is connected to a
SoftmaxOutput operator, which can back-propagate when the corresponding
labels are provided. The states are then connected to the next time
step, which allows back-propagation through time. However, at the end of
the sequence, the final states are not connected to anything. These
dangling outputs are problematic, so we explicitly connect each of them
to a BlockGrad operator, which simply back-propagates a zero gradient and
closes the computation graph.
In the end, we group all the prediction outputs from every time step
into a single SymbolicNode and return it. Optionally, we also group the
final states; this is used when we sample sentences with the trained
LSTM.
Data Provider for Text Sequences
--------------------------------
Now we need to construct a data provider that takes a text file,
divides the text into mini-batches of fixed-length character sequences,
and provides them as one-hot encoded vectors.
Note that there is no fancy feature extraction at all. Each character is
simply encoded as a one-hot vector: a 0-1 vector whose size is the size
of the vocabulary. Here we just construct the vocabulary by collecting
all the unique characters in the training text; for English text there
are not many of them (including punctuation and whitespace). Each input
character is then encoded as a vector of 0s on all coordinates, and 1 on
the coordinate corresponding to that character. The
character-to-coordinate mapping is given by the vocabulary.
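The vocabulary construction and one-hot encoding can be sketched in a few lines of plain Julia 1.x. Sorting the unique characters is an assumption made only to get a deterministic mapping, and `build_vocab` and `encode_onehot` are hypothetical names.

```julia
# Map each unique character of the text to a coordinate 1..vocab_size.
function build_vocab(text::AbstractString)
    chars = sort(unique(collect(text)))
    return Dict(ch => i for (i, ch) in enumerate(chars))
end

# Encode a string as a vocab_size × length(text) matrix: each column is 0
# everywhere except for a 1 at the character's coordinate.
function encode_onehot(text::AbstractString, vocab::AbstractDict{Char,Int})
    chars = collect(text)
    X = zeros(Float32, length(vocab), length(chars))
    for (t, ch) in enumerate(chars)
        X[vocab[ch], t] = 1
    end
    return X
end
```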
The text sequence data provider implements the [Data Providers](@ref) API. We define a `CharSeqProvider` type for this purpose.
The provided data and labels follow the naming convention of the inputs
used when unrolling the LSTM. Note that, apart from `$name_data_$t` and
`$name_label_$t`, the provider also supplies the initial `c` and `h`
states for each layer. This is because we are using the high-level
FeedForward API, which has no notion of time or states, so we feed the
initial states for each sequence from the data provider. Since the
initial states are always zero, we just need to provide constant zero
blobs.
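The zero initial-state inputs can be sketched as a dictionary of named constant blobs in plain Julia 1.x. The naming pattern `$(name)_l$(i)_init_c` and `$(name)_l$(i)_init_h`, the `(num_hidden, batch_size)` shape and the helper name `init_state_blobs` are assumptions for illustration; check the example code for the actual names.

```julia
# Constant zero blobs for the initial c and h of every layer, keyed by the
# input name they must match. Name pattern and shape are assumptions.
function init_state_blobs(name::Symbol, num_layers::Int, num_hidden::Int,
                          batch_size::Int)
    blobs = Dict{Symbol,Matrix{Float32}}()
    for i in 1:num_layers
        blobs[Symbol("$(name)_l$(i)_init_c")] = zeros(Float32, num_hidden, batch_size)
        blobs[Symbol("$(name)_l$(i)_init_h")] = zeros(Float32, num_hidden, batch_size)
    end
    return blobs
end
```

Because these blobs never change, they can be allocated once and handed out unchanged with every batch.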
Next we implement the `eachbatch` method from the [`mx.AbstractDataProvider`](@ref) interface for the
provider. We start by defining the data and label arrays, and the
`DataBatch` object we will provide in each iteration.
The actual data-providing iteration is implemented as a Julia
**coroutine**. This way, we can write the data loading logic as a
simple, coherent `for` loop, and do not need to implement interface
functions like `Base.start`, `Base.next`, etc.
Basically, we partition the text into batches, each batch containing
several contiguous text sequences. Note that at each time step the LSTM
is trained to predict the next character, so the label is the same as
the data, but shifted ahead by one index.
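Assuming Julia 1.x, where a `Channel` built with a `do` block runs its body as a producer task, the coroutine-style loop can be sketched as follows. The text is assumed to be already mapped to integer character indices, each column of a batch is one contiguous sequence, and `batch_channel` is a hypothetical name; the provider in the example fills its data and label arrays and a `DataBatch` instead.

```julia
# Yield (data, label) batches from a Channel; the producer is a plain loop.
# `indices` is the text already mapped to vocabulary indices (assumption).
function batch_channel(indices::Vector{Int}, seq_len::Int, batch_size::Int)
    nseq = (length(indices) - 1) ÷ seq_len  # full sequences (one extra char for labels)
    nbatch = nseq ÷ batch_size              # incomplete tail batch is dropped
    return Channel() do ch
        for b in 0:nbatch-1
            data  = zeros(Int, seq_len, batch_size)
            label = zeros(Int, seq_len, batch_size)
            for j in 1:batch_size
                # Column j of batch b is the (b * batch_size + j)-th sequence.
                first_idx = (b * batch_size + j - 1) * seq_len + 1
                data[:, j]  = indices[first_idx:first_idx+seq_len-1]
                # The label at each step is the next character.
                label[:, j] = indices[first_idx+1:first_idx+seq_len]
            end
            put!(ch, (data, label))
        end
    end
end
```

Iterating with `for (data, label) in batch_channel(indices, seq_len, batch_size)` yields batches until the producer loop finishes and the channel closes.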
Training the LSTM
-----------------
Now we have implemented all the supporting infrastructure for our
char-lstm. To train the model, we just follow the standard high-level
API. First, we construct an LSTM symbolic architecture.
Note that all the parameters are defined in
[examples/char-lstm/config.jl](https://github.com/dmlc/MXNet.jl/blob/master/examples/char-lstm/config.jl).
Now we load the text file and define the data provider. The data file
`input.txt` used in this example is [a tiny Shakespeare
dataset](https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare),
but you can try other text files as well.
The last step is to construct a model and an optimizer, and fit the
model to the data. We use the Adam optimizer (Kingma and Ba, 2014) in
this example. Note that we also use a customized `NLL` evaluation
metric, which calculates the negative log-likelihood during training.
Here is an output sample at the end of the training process.
```
...
INFO: Speed: 357.72 samples/sec
INFO: == Epoch 020 ==========
INFO: ## Training summary
INFO: NLL = 1.4672
INFO: perplexity = 4.3373
INFO: time = 87.2631 seconds
INFO: ## Validation summary
INFO: NLL = 1.6374
INFO: perplexity = 5.1418
INFO: Saved checkpoint to 'char-lstm/checkpoints/ptb-0020.params'
INFO: Speed: 368.74 samples/sec
INFO: Speed: 361.04 samples/sec
INFO: Speed: 360.02 samples/sec
INFO: Speed: 362.34 samples/sec
INFO: Speed: 360.80 samples/sec
INFO: Speed: 362.77 samples/sec
INFO: Speed: 357.18 samples/sec
INFO: Speed: 355.30 samples/sec
INFO: Speed: 362.33 samples/sec
INFO: Speed: 359.23 samples/sec
INFO: Speed: 358.09 samples/sec
INFO: Speed: 356.89 samples/sec
INFO: Speed: 371.91 samples/sec
INFO: Speed: 372.24 samples/sec
INFO: Speed: 356.59 samples/sec
INFO: Speed: 356.64 samples/sec
INFO: Speed: 360.24 samples/sec
INFO: Speed: 360.32 samples/sec
INFO: Speed: 362.38 samples/sec
INFO: == Epoch 021 ==========
INFO: ## Training summary
INFO: NLL = 1.4655
INFO: perplexity = 4.3297
INFO: time = 86.9243 seconds
INFO: ## Validation summary
INFO: NLL = 1.6366
INFO: perplexity = 5.1378
INFO: Saved checkpoint to 'examples/char-lstm/checkpoints/ptb-0021.params'
```
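The two metrics in the log are related by perplexity = exp(NLL); for example, exp(1.4672) ≈ 4.337, matching the epoch-20 training perplexity. A minimal plain-Julia sketch of both follows, assuming the predictions are given as a `vocab_size × N` matrix of probabilities and the labels as the `N` correct indices; it is not the example's own `NLL` metric implementation.

```julia
# Average negative log-probability assigned to the correct next character.
function nll(probs::AbstractMatrix{<:Real}, labels::AbstractVector{Int})
    total = 0.0
    for (n, y) in enumerate(labels)
        total -= log(probs[y, n])  # probability of the true label in column n
    end
    return total / length(labels)
end

# Perplexity is the exponential of the average NLL.
perplexity(nll_value::Real) = exp(nll_value)
```

A model that predicts the uniform distribution over `V` characters has NLL `log(V)` and perplexity `V`, so perplexity can be read as the effective number of characters the model is choosing between.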
Sampling Random Sentences
-------------------------
After training the LSTM, we can sample random sentences from the
trained model. The sampler works as follows:
- Start from some fixed character, for example `a`, and feed it as
  input to the LSTM.
- The LSTM produces an output distribution over the vocabulary and a
  state at the first time step. We sample a character from the output
  distribution and fix it as the second character.
- At the next time step, we feed the previously sampled character as
  input and continue running the LSTM, also taking the previous states
  (instead of the zero initial states).
- Continue running until we have sampled enough characters.
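The sampling procedure above can be sketched in plain Julia 1.x. `step(index, state)` is a hypothetical callback that runs one LSTM time step and returns `(probabilities, new_state)`, and `rev_vocab` is an assumed index-to-character table; characters are drawn from the categorical distribution by walking its cumulative sum.

```julia
using Random

# Draw an index from the categorical distribution `probs` by walking its
# cumulative sum until it exceeds a uniform random threshold.
function sample_index(rng::AbstractRNG, probs::AbstractVector{<:Real})
    r = rand(rng) * sum(probs)
    acc = 0.0
    for (i, p) in enumerate(probs)
        acc += p
        r < acc && return i
    end
    # Floating-point round-off fallback: last index with nonzero probability.
    return findlast(p -> p > 0, probs)
end

# Generate `n` characters after `first_index`. `step(index, state)` is a
# hypothetical callback returning (probabilities, new_state) for one time step.
function sample_text(step, rev_vocab, first_index::Int, init_state, n::Int;
                     rng::AbstractRNG = MersenneTwister(0))
    index, state = first_index, init_state
    out = [rev_vocab[index]]
    for _ in 1:n
        probs, state = step(index, state)  # feed back the previous state
        index = sample_index(rng, probs)   # sampled char becomes the next input
        push!(out, rev_vocab[index])
    end
    return join(out)
end
```

Passing the returned `state` into the next call is what carries the context forward from one sampled character to the next.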
Note that we are running with mini-batches, so several sentences can be
sampled simultaneously. Here are some sampled outputs from a network I
trained for around half an hour on the Shakespeare dataset. Note that all
the line breaks, punctuation and upper/lower-case letters are produced
by the sampler itself. I did not do any post-processing.
```
## Sample 1
all have sir,
Away will fill'd in His time, I'll keep her, do not madam, if they here? Some more ha?
## Sample 2
am.
CLAUDIO:
Hone here, let her, the remedge, and I know not slept a likely, thou some soully free?
## Sample 3
arrel which noble thing
The exchnachsureding worns: I ne'er drunken Biancas, fairer, than the lawfu?
## Sample 4
augh assalu, you'ld tell me corn;
Farew. First, for me of a loved. Has thereat I knock you presents?
## Sample 5
ame the first answer.
MARIZARINIO:
Door of Angelo as her lord, shrield liken Here fellow the fool ?
## Sample 6
ad well.
CLAUDIO:
Soon him a fellows here; for her fine edge in a bogms' lord's wife.
LUCENTIO:
I?
## Sample 7
adrezilian measure.
LUCENTIO:
So, help'd you hath nes have a than dream's corn, beautio, I perchas?
## Sample 8
as eatter me;
The girlly: and no other conciolation!
BISTRUMIO:
I have be rest girl. O, that I a h?
## Sample 9
and is intend you sort:
What held her all 'clama's for maffice. Some servant.' what I say me the cu?
## Sample 10
an thoughts will said in our pleasue,
Not scanin on him that you live; believaries she.
ISABELLLLL?
```
See [Andrej Karpathy's blog
post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) for more
examples and links, including Linux source code, algebraic geometry
theorems, and even cooking recipes. The code for sampling can be found
in
[examples/char-lstm/sampler.jl](https://github.com/apache/incubator-mxnet/tree/master/julia/examples/char-lstm/sampler.jl).
Visualizing the LSTM
--------------------
Finally, you can visualize the LSTM by calling `to_graphviz` on the
constructed LSTM symbolic architecture. We only show an example of a
1-layer, 2-time-step LSTM below. The automatic layout produced by
GraphViz is definitely much less clear than [Christopher Olah's
illustrations](http://colah.github.io/posts/2015-08-Understanding-LSTMs/),
but can otherwise be very useful for debugging. As we can see, the
LSTM unfolded over time is just a (very) deep neural network. The
complete code for producing this visualization can be found in
[examples/char-lstm/visualize.jl](https://github.com/apache/incubator-mxnet/tree/master/julia/examples/char-lstm/visualize.jl).
![image](images/char-lstm-vis.svg)