{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import mxnet as mx\n",
"import numpy as np\n",
"import random\n",
"import bisect"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# set up logging\n",
"import logging\n",
"reload(logging)\n",
"logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.DEBUG, datefmt='%I:%M:%S')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Glance of LSTM structure and embedding layer\n",
"\n",
"We will build a LSTM network to learn from char only. At each time, input is a char. We will see this LSTM is able to learn words and grammers from sequence of chars.\n",
"\n",
"The following figure is showing an unrolled LSTM network, and how we generate embedding of a char. The one-hot to embedding operation is a special case of fully connected network.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"http://data.mxnet.io/mxnet/data/char-rnn_1.png\">\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"http://data.mxnet.io/mxnet/data/char-rnn_2.png\">"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from lstm import lstm_unroll, lstm_inference_symbol\n",
"from bucket_io import BucketSentenceIter\n",
"from rnn_model import LSTMInferenceModel"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Read from doc\n",
"def read_content(path):\n",
" with open(path) as ins:\n",
" content = ins.read()\n",
" return content\n",
"\n",
"# Build a vocabulary of what char we have in the content\n",
"def build_vocab(path):\n",
" content = read_content(path)\n",
" content = list(content)\n",
" idx = 1 # 0 is left for zero-padding\n",
" the_vocab = {}\n",
" for word in content:\n",
" if len(word) == 0:\n",
" continue\n",
" if not word in the_vocab:\n",
" the_vocab[word] = idx\n",
" idx += 1\n",
" return the_vocab\n",
"\n",
"# We will assign each char with a special numerical id\n",
"def text2id(sentence, the_vocab):\n",
" words = list(sentence)\n",
" words = [the_vocab[w] for w in words if len(w) > 0]\n",
" return words"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Evaluation \n",
"def Perplexity(label, pred):\n",
" label = label.T.reshape((-1,))\n",
" loss = 0.\n",
" for i in range(pred.shape[0]):\n",
" loss += -np.log(max(1e-10, pred[i][int(label[i])]))\n",
" return np.exp(loss / label.size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Get Data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"data_url = \"http://data.mxnet.io/mxnet/data/char_lstm.zip\"\n",
"os.system(\"wget %s\" % data_url)\n",
"os.system(\"unzip -o char_lstm.zip\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sample training data:\n",
"```\n",
"all to Renewal Keynote Address Call to Renewal Pt 1Call to Renewal Part 2 TOPIC: Our Past, Our Future & Vision for America June\n",
"28, 2006 Call to Renewal' Keynote Address Complete Text Good morning. I appreciate the opportunity to speak here at the Call to R\n",
"enewal's Building a Covenant for a New America conference. I've had the opportunity to take a look at your Covenant for a New Ame\n",
"rica. It is filled with outstanding policies and prescriptions for much of what ails this country. So I'd like to congratulate yo\n",
"u all on the thoughtful presentations you've given so far about poverty and justice in America, and for putting fire under the fe\n",
"et of the political leadership here in Washington.But today I'd like to talk about the connection between religion and politics a\n",
"nd perhaps offer some thoughts about how we can sort through some of the often bitter arguments that we've been seeing over the l\n",
"ast several years.I do so because, as you all know, we can affirm the importance of poverty in the Bible; and we can raise up and\n",
" pass out this Covenant for a New America. We can talk to the press, and we can discuss the religious call to address poverty and\n",
" environmental stewardship all we want, but it won't have an impact unless we tackle head-on the mutual suspicion that sometimes\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LSTM Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# The batch size for training\n",
"batch_size = 32\n",
"# We can support various length input\n",
"# For this problem, we cut each input sentence to length of 129\n",
"# So we only need fix length bucket\n",
"buckets = [129]\n",
"# hidden unit in LSTM cell\n",
"num_hidden = 512\n",
"# embedding dimension, which is, map a char to a 256 dim vector\n",
"num_embed = 256\n",
"# number of lstm layer\n",
"num_lstm_layer = 3"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# we will show a quick demo in 2 epoch\n",
"# and we will see result by training 75 epoch\n",
"num_epoch = 2\n",
"# learning rate \n",
"learning_rate = 0.01\n",
"# we will use pure sgd without momentum\n",
"momentum = 0.0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# we can select multi-gpu for training\n",
"# for this demo we only use one\n",
"devs = [mx.context.gpu(i) for i in range(1)]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# build char vocabluary from input\n",
"vocab = build_vocab(\"./obama.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# generate symbol for a length\n",
"def sym_gen(seq_len):\n",
" return lstm_unroll(num_lstm_layer, seq_len, len(vocab) + 1,\n",
" num_hidden=num_hidden, num_embed=num_embed,\n",
" num_label=len(vocab) + 1, dropout=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# initalize states for LSTM\n",
"init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]\n",
"init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]\n",
"init_states = init_c + init_h"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Summary of dataset ==================\n",
"bucket of len 129 : 8290 samples\n"
]
}
],
"source": [
"# we can build an iterator for text\n",
"data_train = BucketSentenceIter(\"./obama.txt\", vocab, buckets, batch_size,\n",
" init_states, seperate_char='\\n',\n",
" text2id=text2id, read_content=read_content)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# the network symbol\n",
"symbol = sym_gen(buckets[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Train a LSTM network as simple as feedforward network\n",
"model = mx.model.FeedForward(ctx=devs,\n",
" symbol=symbol,\n",
" num_epoch=num_epoch,\n",
" learning_rate=learning_rate,\n",
" momentum=momentum,\n",
" wd=0.0001,\n",
" initializer=mx.init.Xavier(factor_type=\"in\", magnitude=2.34))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"05:01:35 INFO:Start training with [gpu(0)]\n"
]
}
],
"source": [
"# Fit it\n",
"model.fit(X=data_train,\n",
" eval_metric = mx.metric.np(Perplexity),\n",
" batch_end_callback=mx.callback.Speedometer(batch_size, 50),\n",
" epoch_end_callback=mx.callback.do_checkpoint(\"obama\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference from model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# helper strcuture for prediction\n",
"def MakeRevertVocab(vocab):\n",
" dic = {}\n",
" for k, v in vocab.items():\n",
" dic[v] = k\n",
" return dic"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# make input from char\n",
"def MakeInput(char, vocab, arr):\n",
" idx = vocab[char]\n",
" tmp = np.zeros((1,))\n",
" tmp[0] = idx\n",
" arr[:] = tmp"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# helper function for random sample \n",
"def _cdf(weights):\n",
" total = sum(weights)\n",
" result = []\n",
" cumsum = 0\n",
" for w in weights:\n",
" cumsum += w\n",
" result.append(cumsum / total)\n",
" return result\n",
"\n",
"def _choice(population, weights):\n",
" assert len(population) == len(weights)\n",
" cdf_vals = _cdf(weights)\n",
" x = random.random()\n",
" idx = bisect.bisect(cdf_vals, x)\n",
" return population[idx]\n",
"\n",
"# we can use random output or fixed output by choosing largest probability\n",
"def MakeOutput(prob, vocab, sample=False, temperature=1.):\n",
" if sample == False:\n",
" idx = np.argmax(prob, axis=1)[0]\n",
" else:\n",
" fix_dict = [\"\"] + [vocab[i] for i in range(1, len(vocab) + 1)]\n",
" scale_prob = np.clip(prob, 1e-6, 1 - 1e-6)\n",
" rescale = np.exp(np.log(scale_prob) / temperature)\n",
" rescale[:] /= rescale.sum()\n",
" return _choice(fix_dict, rescale[0, :])\n",
" try:\n",
" char = vocab[idx]\n",
" except:\n",
" char = ''\n",
" return char"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# load from check-point\n",
"_, arg_params, __ = mx.model.load_checkpoint(\"obama\", 75)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# build an inference model\n",
"model = LSTMInferenceModel(num_lstm_layer, len(vocab) + 1,\n",
" num_hidden=num_hidden, num_embed=num_embed,\n",
" num_label=len(vocab) + 1, arg_params=arg_params, ctx=mx.gpu(), dropout=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# generate a sequence of 1200 chars\n",
"\n",
"seq_length = 1200\n",
"input_ndarray = mx.nd.zeros((1,))\n",
"revert_vocab = MakeRevertVocab(vocab)\n",
"# Feel free to change the starter sentence\n",
"output ='The joke'\n",
"random_sample = True\n",
"new_sentence = True\n",
"\n",
"ignore_length = len(output)\n",
"\n",
"for i in range(seq_length):\n",
" if i <= ignore_length - 1:\n",
" MakeInput(output[i], vocab, input_ndarray)\n",
" else:\n",
" MakeInput(output[-1], vocab, input_ndarray)\n",
" prob = model.forward(input_ndarray, new_sentence)\n",
" new_sentence = False\n",
" next_char = MakeOutput(prob, revert_vocab, random_sample)\n",
" if next_char == '':\n",
" new_sentence = True\n",
" if i >= ignore_length - 1:\n",
" output += next_char\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The joke learning to be struggle for our daughter. We are the ones who can't pay their relationship. The Judiciary Commencement ce designed to deficit to the party of almost unemployment instead, just to look at home, little proof for America, Carguin are showing struggle against our pride. That if you came from tharger by a party that would increase the pervasive sense of new global warming against the challenge of governments - to get a corporation.As a highealth care, your own retirement security information about his family decided to get a job or aspect what will allow cannot simply by sagging high school system and stin twenty-five years. But led my faith designed to leave all their buddets and responsibility. But I sund this dangerous weapons, explain withdrawal oful -clears axdication in Iraq.What is the time for American policy became their efforts, and given them that a man doesn't make sure that that my own, you'll be faced with you. Four years, reforms illness all that kind of choose to understand is a broadeary. You instills in search of a reducithis recision, of us, with public services from using that barealies, but that must continue to limb line, they know th\n"
]
}
],
"source": [
"# Let's see what we can learned from char in Obama's speech.\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}