{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Licensed to the Apache Software Foundation (ASF) under one\n",
"# or more contributor license agreements. See the NOTICE file\n",
"# distributed with this work for additional information\n",
"# regarding copyright ownership. The ASF licenses this file\n",
"# to you under the Apache License, Version 2.0 (the\n",
"# \"License\"); you may not use this file except in compliance\n",
"# with the License. You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing,\n",
"# software distributed under the License is distributed on an\n",
"# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
"# KIND, either express or implied. See the License for the\n",
"# specific language governing permissions and limitations\n",
"# under the License."
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"# Using a bi-lstm to sort a sequence of integers"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import random\n",
"import string\n",
"\n",
"import mxnet as mx\n",
"from mxnet import gluon, np\n",
"import numpy as onp"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Data Preparation"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"max_num = 999\n",
"dataset_size = 60000\n",
"seq_len = 5\n",
"split = 0.8\n",
"batch_size = 512\n",
"ctx = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We are getting a dataset of **dataset_size** sequences of integers of length **seq_len** between **0** and **max_num**. We use **split*100%** of them for training and the rest for testing.\n",
"\n",
"\n",
"For example:\n",
"\n",
"50 10 200 999 30\n",
"\n",
"Should return\n",
"\n",
"10 30 50 200 999"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"X = mx.np.random.uniform(low=0, high=max_num, size=(dataset_size, seq_len)).astype('int32').asnumpy()\n",
"Y = X.copy()\n",
"Y.sort() #Let's sort X to get the target"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"print(\"Input {}\\nTarget {}\".format(X[0].tolist(), Y[0].tolist()))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input [548, 592, 714, 843, 602]\n",
"Target [548, 592, 602, 714, 843]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"For the purpose of training, we encode the input as characters rather than numbers"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"vocab = string.digits + \" \"\n",
"print(vocab)\n",
"vocab_idx = { c:i for i,c in enumerate(vocab)}\n",
"print(vocab_idx)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0123456789 \n",
"{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, ' ': 10}\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We write a transform that will convert our numbers into text of maximum length **max_len**, and one-hot encode the characters.\n",
"For example:\n",
"\n",
"\"30 10\" corresponding indices are [3, 0, 10, 1, 0]\n",
"\n",
"We then one hot encode that and get a matrix representation of our input. We don't need to encode our target as the loss we are going to use support sparse labels"
],
"metadata": {}
},
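{
"cell_type": "markdown",
"source": [
"As a quick check, we can map the characters of that example to their indices directly:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Each character of \"30 10\" maps to its vocabulary index\n",
"print([vocab_idx[c] for c in \"30 10\"])  # [3, 0, 10, 1, 0]"
],
"outputs": [],
"metadata": {}
},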
{
"cell_type": "code",
"execution_count": 6,
"source": [
"max_len = len(str(max_num))*seq_len+(seq_len-1)\n",
"print(\"Maximum length of the string: %s\" % max_len)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Maximum length of the string: 19\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"def transform(x, y):\n",
" x_string = ' '.join(map(str, x.tolist()))\n",
" x_string_padded = x_string + ' '*(max_len-len(x_string))\n",
" x = [vocab_idx[c] for c in x_string_padded]\n",
" y_string = ' '.join(map(str, y.tolist()))\n",
" y_string_padded = y_string + ' '*(max_len-len(y_string))\n",
" y = [vocab_idx[c] for c in y_string_padded]\n",
" return mx.npx.one_hot(mx.nd.array(x), len(vocab)), mx.np.array(y)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"split_idx = int(split*len(X))\n",
"train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)\n",
"test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"print(\"Input {}\".format(X[0]))\n",
"print(\"Transformed data Input {}\".format(train_dataset[0][0]))\n",
"print(\"Target {}\".format(Y[0]))\n",
"print(\"Transformed data Target {}\".format(train_dataset[0][1]))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input [548 592 714 843 602]\n",
"Transformed data Input \n",
"[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
" [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
" [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
" [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
" [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
" [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n",
" [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
" [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
" [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]\n",
"<NDArray 19x11 @cpu(0)>\n",
"Target [548 592 602 714 843]\n",
"Transformed data Target \n",
"[ 5. 4. 8. 10. 5. 9. 2. 10. 6. 0. 2. 10. 7. 1. 4. 10. 8. 4.\n",
" 3.]\n",
"<NDArray 19 @cpu(0)>\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')\n",
"test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')"
],
"outputs": [],
"metadata": {}
},
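{
"cell_type": "markdown",
"source": [
"As a sanity check, we can pull a single batch from the loader; with the settings above, inputs should have shape (batch_size, max_len, len(vocab)) and labels (batch_size, max_len):"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Fetch one batch and inspect the shapes\n",
"data, label = next(iter(train_data))\n",
"print(data.shape, label.shape)  # expected: (512, 19, 11) (512, 19)"
],
"outputs": [],
"metadata": {}
},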
{
"cell_type": "markdown",
"source": [
"## Creating the network"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"source": [
"net = gluon.nn.HybridSequential()\n",
"net.add(\n",
" gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),\n",
" gluon.nn.Dense(len(vocab), flatten=False)\n",
")"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"source": [
"net.initialize(mx.init.Xavier(), ctx=ctx)"
],
"outputs": [],
"metadata": {}
},
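{
"cell_type": "markdown",
"source": [
"The bidirectional LSTM outputs one 256-dimensional vector per time step (128 per direction), and the Dense layer with **flatten=False** maps each time step to **len(vocab)** logits. A quick shape check on a dummy batch (a minimal sketch):"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Run a zero-filled dummy batch through the freshly initialized network\n",
"dummy = mx.np.zeros((1, max_len, len(vocab))).to_device(ctx)\n",
"print(net(dummy).shape)  # expected: (1, 19, 11)"
],
"outputs": [],
"metadata": {}
},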
{
"cell_type": "code",
"execution_count": 13,
"source": [
"loss = gluon.loss.SoftmaxCELoss()"
],
"outputs": [],
"metadata": {}
},
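{
"cell_type": "markdown",
"source": [
"With sparse labels, the loss takes logits of shape (batch, seq, vocab) and integer labels of shape (batch, seq), which is why we did not have to one-hot encode the targets. A small sketch on dummy data:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sparse labels: one class index per time step, no one-hot encoding needed\n",
"dummy_out = mx.np.zeros((2, max_len, len(vocab)))\n",
"dummy_lbl = mx.np.zeros((2, max_len))\n",
"print(loss(dummy_out, dummy_lbl).shape)  # one loss value per sample: (2,)"
],
"outputs": [],
"metadata": {}
},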
{
"cell_type": "markdown",
"source": [
"We use a learning rate schedule to improve the convergence of the model"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 14,
"source": [
"schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n",
"schedule.base_lr = 0.01"
],
"outputs": [],
"metadata": {}
},
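{
"cell_type": "markdown",
"source": [
"With these settings the learning rate is multiplied by 0.75 every **len(train_data)*10** updates, i.e. every 10 epochs. Since MXNet schedulers are callable with an update count but mutate their internal state when called, we can preview the decay on a throwaway instance (a quick sketch):"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Preview the decay on a separate scheduler so the real one stays untouched\n",
"preview = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75, base_lr=0.01)\n",
"for num_update in [1, len(train_data)*10 + 1, len(train_data)*20 + 1]:\n",
"    print(num_update, preview(num_update))  # 0.01, 0.0075, 0.005625"
],
"outputs": [],
"metadata": {}
},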
{
"cell_type": "code",
"execution_count": 15,
"source": [
"trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Training loop"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"source": [
"epochs = 100\n",
"for e in range(epochs):\n",
" epoch_loss = 0.\n",
" for i, (data, label) in enumerate(train_data):\n",
" data = data.as_in_context(ctx)\n",
" label = label.as_in_context(ctx)\n",
"\n",
" with mx.autograd.record():\n",
" output = net(data)\n",
" l = loss(output, label)\n",
"\n",
" l.backward()\n",
" trainer.step(data.shape[0])\n",
" \n",
" epoch_loss += l.mean()\n",
" \n",
" print(\"Epoch [{}] Loss: {}, LR {}\".format(e, epoch_loss.item()/(i+1), trainer.learning_rate))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch [0] Loss: 1.6627886372227823, LR 0.01\n",
"Epoch [1] Loss: 1.210370733382854, LR 0.01\n",
"Epoch [2] Loss: 0.9692377131035987, LR 0.01\n",
"Epoch [3] Loss: 0.7976046623067653, LR 0.01\n",
"Epoch [4] Loss: 0.5714595343476983, LR 0.01\n",
"Epoch [5] Loss: 0.4458411196444897, LR 0.01\n",
"Epoch [6] Loss: 0.36039798817736035, LR 0.01\n",
"Epoch [7] Loss: 0.32665719377233626, LR 0.01\n",
"Epoch [8] Loss: 0.262064205702915, LR 0.01\n",
"Epoch [9] Loss: 0.22285924059279422, LR 0.0075\n",
"Epoch [10] Loss: 0.19018426854559717, LR 0.0075\n",
"Epoch [11] Loss: 0.1718730723604243, LR 0.0075\n",
"Epoch [12] Loss: 0.15736752171670237, LR 0.0075\n",
"Epoch [13] Loss: 0.14579375246737866, LR 0.0075\n",
"Epoch [14] Loss: 0.13546599733068587, LR 0.0075\n",
"Epoch [15] Loss: 0.12490207590955368, LR 0.0075\n",
"Epoch [16] Loss: 0.11803316300915133, LR 0.0075\n",
"Epoch [17] Loss: 0.10653189395336395, LR 0.0075\n",
"Epoch [18] Loss: 0.10514750379197141, LR 0.0075\n",
"Epoch [19] Loss: 0.09590611559279422, LR 0.005625\n",
"Epoch [20] Loss: 0.08146028108494256, LR 0.005625\n",
"Epoch [21] Loss: 0.07707348782965477, LR 0.005625\n",
"Epoch [22] Loss: 0.07206193436967566, LR 0.005625\n",
"Epoch [23] Loss: 0.07001185417175293, LR 0.005625\n",
"Epoch [24] Loss: 0.06797058351578252, LR 0.005625\n",
"Epoch [25] Loss: 0.0649358110224947, LR 0.005625\n",
"Epoch [26] Loss: 0.06219124286732775, LR 0.005625\n",
"Epoch [27] Loss: 0.06075144828634059, LR 0.005625\n",
"Epoch [28] Loss: 0.05711334495134251, LR 0.005625\n",
"Epoch [29] Loss: 0.054747099572039666, LR 0.00421875\n",
"Epoch [30] Loss: 0.0441775271233092, LR 0.00421875\n",
"Epoch [31] Loss: 0.041551097910454936, LR 0.00421875\n",
"Epoch [32] Loss: 0.04095017269093503, LR 0.00421875\n",
"Epoch [33] Loss: 0.04045371045457556, LR 0.00421875\n",
"Epoch [34] Loss: 0.038867686657195394, LR 0.00421875\n",
"Epoch [35] Loss: 0.038131744303601854, LR 0.00421875\n",
"Epoch [36] Loss: 0.039834817250569664, LR 0.00421875\n",
"Epoch [37] Loss: 0.03669035941996473, LR 0.00421875\n",
"Epoch [38] Loss: 0.03373505967728635, LR 0.00421875\n",
"Epoch [39] Loss: 0.03164981273894615, LR 0.0031640625\n",
"Epoch [40] Loss: 0.025532766055035336, LR 0.0031640625\n",
"Epoch [41] Loss: 0.022659448867148543, LR 0.0031640625\n",
"Epoch [42] Loss: 0.02307056112492338, LR 0.0031640625\n",
"Epoch [43] Loss: 0.02236944056571798, LR 0.0031640625\n",
"Epoch [44] Loss: 0.022204211963120328, LR 0.0031640625\n",
"Epoch [45] Loss: 0.02262336903430046, LR 0.0031640625\n",
"Epoch [46] Loss: 0.02253308448385685, LR 0.0031640625\n",
"Epoch [47] Loss: 0.025286573044797207, LR 0.0031640625\n",
"Epoch [48] Loss: 0.02439300988310127, LR 0.0031640625\n",
"Epoch [49] Loss: 0.017976388018181983, LR 0.002373046875\n",
"Epoch [50] Loss: 0.014343131095805067, LR 0.002373046875\n",
"Epoch [51] Loss: 0.013039355582379281, LR 0.002373046875\n",
"Epoch [52] Loss: 0.011884741885687715, LR 0.002373046875\n",
"Epoch [53] Loss: 0.011438189668858305, LR 0.002373046875\n",
"Epoch [54] Loss: 0.011447292693117832, LR 0.002373046875\n",
"Epoch [55] Loss: 0.014212571560068334, LR 0.002373046875\n",
"Epoch [56] Loss: 0.019900493724371797, LR 0.002373046875\n",
"Epoch [57] Loss: 0.02102568301748722, LR 0.002373046875\n",
"Epoch [58] Loss: 0.01346214400961044, LR 0.002373046875\n",
"Epoch [59] Loss: 0.010107964911359422, LR 0.0017797851562500002\n",
"Epoch [60] Loss: 0.008353193600972494, LR 0.0017797851562500002\n",
"Epoch [61] Loss: 0.007678258292218472, LR 0.0017797851562500002\n",
"Epoch [62] Loss: 0.007262124660167288, LR 0.0017797851562500002\n",
"Epoch [63] Loss: 0.00705223578087827, LR 0.0017797851562500002\n",
"Epoch [64] Loss: 0.006788556293774677, LR 0.0017797851562500002\n",
"Epoch [65] Loss: 0.006473606571238091, LR 0.0017797851562500002\n",
"Epoch [66] Loss: 0.006206096486842378, LR 0.0017797851562500002\n",
"Epoch [67] Loss: 0.00584477313021396, LR 0.0017797851562500002\n",
"Epoch [68] Loss: 0.005648705267137097, LR 0.0017797851562500002\n",
"Epoch [69] Loss: 0.006481769871204458, LR 0.0013348388671875003\n",
"Epoch [70] Loss: 0.008430448618341, LR 0.0013348388671875003\n",
"Epoch [71] Loss: 0.006877245421105242, LR 0.0013348388671875003\n",
"Epoch [72] Loss: 0.005671108281740578, LR 0.0013348388671875003\n",
"Epoch [73] Loss: 0.004832422162624116, LR 0.0013348388671875003\n",
"Epoch [74] Loss: 0.004441103402604448, LR 0.0013348388671875003\n",
"Epoch [75] Loss: 0.004216198591475791, LR 0.0013348388671875003\n",
"Epoch [76] Loss: 0.004041922989711967, LR 0.0013348388671875003\n",
"Epoch [77] Loss: 0.003937713643337818, LR 0.0013348388671875003\n",
"Epoch [78] Loss: 0.010251983049068046, LR 0.0013348388671875003\n",
"Epoch [79] Loss: 0.01829354052848004, LR 0.0010011291503906252\n",
"Epoch [80] Loss: 0.006723233448561802, LR 0.0010011291503906252\n",
"Epoch [81] Loss: 0.004397524798170049, LR 0.0010011291503906252\n",
"Epoch [82] Loss: 0.0038475305476087206, LR 0.0010011291503906252\n",
"Epoch [83] Loss: 0.003591177945441388, LR 0.0010011291503906252\n",
"Epoch [84] Loss: 0.003425112014175743, LR 0.0010011291503906252\n",
"Epoch [85] Loss: 0.0032633850549129728, LR 0.0010011291503906252\n",
"Epoch [86] Loss: 0.0031762316505959693, LR 0.0010011291503906252\n",
"Epoch [87] Loss: 0.0030452777096565734, LR 0.0010011291503906252\n",
"Epoch [88] Loss: 0.002950224184220837, LR 0.0010011291503906252\n",
"Epoch [89] Loss: 0.002821172171450676, LR 0.0007508468627929689\n",
"Epoch [90] Loss: 0.002725780961361337, LR 0.0007508468627929689\n",
"Epoch [91] Loss: 0.002660556359493986, LR 0.0007508468627929689\n",
"Epoch [92] Loss: 0.0026011724946319414, LR 0.0007508468627929689\n",
"Epoch [93] Loss: 0.0025355776256703317, LR 0.0007508468627929689\n",
"Epoch [94] Loss: 0.0024825221997626283, LR 0.0007508468627929689\n",
"Epoch [95] Loss: 0.0024245587435174497, LR 0.0007508468627929689\n",
"Epoch [96] Loss: 0.002365282145879602, LR 0.0007508468627929689\n",
"Epoch [97] Loss: 0.0023112583984719946, LR 0.0007508468627929689\n",
"Epoch [98] Loss: 0.002257173682780976, LR 0.0007508468627929689\n",
"Epoch [99] Loss: 0.002162747085094452, LR 0.0005631351470947267\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Testing"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We get a random element from the testing set"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 17,
"source": [
"n = random.randint(0, len(test_data)-1)\n",
"\n",
"x_orig = X[split_idx+n]\n",
"y_orig = Y[split_idx+n]"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 41,
"source": [
"def get_pred(x):\n",
" x, _ = transform(x, x)\n",
" output = net(mx.np.expand_dims(x.to_device(ctx), axis=0))\n",
"\n",
" # Convert output back to string\n",
" pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])\n",
" return pred"
],
"outputs": [],
"metadata": {}
},
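{
"cell_type": "markdown",
"source": [
"Before spot-checking individual predictions, we can estimate overall performance; the sketch below measures character-level accuracy, i.e. the fraction of output characters matching the label, over the whole test set:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Character-level accuracy over the test set\n",
"correct, total = 0, 0\n",
"for data, label in test_data:\n",
"    data = data.to_device(ctx)\n",
"    label = label.to_device(ctx)\n",
"    pred = net(data).argmax(axis=2)\n",
"    correct += int((pred == label).sum().item())\n",
"    total += label.size\n",
"print(\"Character-level accuracy: {:.4f}\".format(correct / total))"
],
"outputs": [],
"metadata": {}
},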
{
"cell_type": "markdown",
"source": [
"Printing the result"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 43,
"source": [
"x_ = ' '.join(map(str,x_orig))\n",
"label = ' '.join(map(str,y_orig))\n",
"print(\"X {}\\nPredicted {}\\nLabel {}\".format(x_, get_pred(x_orig), label))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"X 611 671 275 871 944\n",
"Predicted 275 611 671 871 944\n",
"Label 275 611 671 871 944\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We can also pick our own example, and the network manages to sort it without problem:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 66,
"source": [
"print(get_pred(onp.array([500, 30, 999, 10, 130])))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 30 130 500 999 \n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The model has even learned to generalize to examples not on the training set"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 64,
"source": [
"print(\"Only four numbers:\", get_pred(onp.array([105, 302, 501, 202])))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Only four numbers: 105 202 302 501 \n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"However we can see it has trouble with other edge cases:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 63,
"source": [
"print(\"Small digits:\", get_pred(onp.array([10, 3, 5, 2, 8])))\n",
"print(\"Small digits, 6 numbers:\", get_pred(onp.array([10, 33, 52, 21, 82, 10])))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Small digits: 8 0 42 28 \n",
"Small digits, 6 numbers: 10 0 20 82 71 115 \n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"This could be improved by adjusting the training dataset accordingly"
],
"metadata": {}
}
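,
{
"cell_type": "markdown",
"source": [
"For instance, a hypothetical helper like the one below (a sketch, not part of the pipeline above) could sample each training sequence with a random magnitude so that small numbers are well represented:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Hypothetical sketch: generate training sequences with varying magnitudes\n",
"# so the model also sees small numbers such as the edge cases above\n",
"def make_varied_data(size, length=seq_len):\n",
"    highs = onp.random.choice([10, 100, max_num], size=(size, 1))\n",
"    X_varied = (onp.random.uniform(size=(size, length)) * highs).astype('int32')\n",
"    Y_varied = onp.sort(X_varied, axis=1)\n",
"    return X_varied, Y_varied\n",
"\n",
"X_varied, Y_varied = make_varied_data(dataset_size)\n",
"print(X_varied[0].tolist(), Y_varied[0].tolist())"
],
"outputs": [],
"metadata": {}
}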
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}