| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# Licensed to the Apache Software Foundation (ASF) under one\n", |
| "# or more contributor license agreements. See the NOTICE file\n", |
| "# distributed with this work for additional information\n", |
| "# regarding copyright ownership. The ASF licenses this file\n", |
| "# to you under the Apache License, Version 2.0 (the\n", |
| "# \"License\"); you may not use this file except in compliance\n", |
| "# with the License. You may obtain a copy of the License at\n", |
| "#\n", |
| "# http://www.apache.org/licenses/LICENSE-2.0\n", |
| "#\n", |
| "# Unless required by applicable law or agreed to in writing,\n", |
| "# software distributed under the License is distributed on an\n", |
| "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", |
| "# KIND, either express or implied. See the License for the\n", |
| "# specific language governing permissions and limitations\n", |
| "# under the License." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# Using a bidirectional LSTM (bi-LSTM) to sort a sequence of integers" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "source": [ |
| "import random\n", |
| "import string\n", |
| "\n", |
| "import mxnet as mx\n", |
| "from mxnet import gluon, np\n", |
| "import numpy as onp" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Data Preparation" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "source": [ |
| "max_num = 999\n", |
| "dataset_size = 60000\n", |
| "seq_len = 5\n", |
| "split = 0.8\n", |
| "batch_size = 512\n", |
| "ctx = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "We generate a dataset of **dataset_size** sequences, each containing **seq_len** integers between **0** and **max_num**. We use a fraction **split** of them (80% here) for training and the rest for testing.\n", |
| "\n", |
| "\n", |
| "For example:\n", |
| "\n", |
| "50 10 200 999 30\n", |
| "\n", |
| "should return:\n", |
| "\n", |
| "10 30 50 200 999" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "source": [ |
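| "# uniform() samples from [low, high) and the int32 cast truncates, so use max_num+1 to include max_num\n", |
| "X = mx.np.random.uniform(low=0, high=max_num+1, size=(dataset_size, seq_len)).astype('int32').asnumpy()\n", |
| "Y = X.copy()\n", |
| "Y.sort()  # Sort each row of the copy to get the target" |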
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "source": [ |
| "print(\"Input {}\\nTarget {}\".format(X[0].tolist(), Y[0].tolist()))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Input [548, 592, 714, 843, 602]\n", |
| "Target [548, 592, 602, 714, 843]\n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "For training, we encode the input as a sequence of characters rather than as numbers." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "source": [ |
| "vocab = string.digits + \" \"\n", |
| "print(vocab)\n", |
| "vocab_idx = { c:i for i,c in enumerate(vocab)}\n", |
| "print(vocab_idx)" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "0123456789 \n", |
| "{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, ' ': 10}\n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "We write a transform that converts our numbers into text padded with spaces to length **max_len**, and then maps each character to its index in the vocabulary.\n", |
| "\n", |
| "For example, the string \"30 10\" corresponds to the indices [3, 0, 10, 1, 0].\n", |
| "\n", |
| "We then one-hot encode these indices to get a matrix representation of our input. We don't need to one-hot encode the target, because the loss we are going to use supports sparse labels." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "source": [ |
| "max_len = len(str(max_num))*seq_len+(seq_len-1)\n", |
| "print(\"Maximum length of the string: %s\" % max_len)" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Maximum length of the string: 19\n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "source": [ |
| "def transform(x, y):\n", |
| "    # Join the numbers into a space-separated string and right-pad it with spaces to max_len\n", |
| "    x_string = ' '.join(map(str, x.tolist()))\n", |
| "    x_string_padded = x_string + ' '*(max_len-len(x_string))\n", |
| "    x = [vocab_idx[c] for c in x_string_padded]\n", |
| "    y_string = ' '.join(map(str, y.tolist()))\n", |
| "    y_string_padded = y_string + ' '*(max_len-len(y_string))\n", |
| "    y = [vocab_idx[c] for c in y_string_padded]\n", |
| "    # One-hot encode the input only; the target stays as class indices\n", |
| "    return mx.npx.one_hot(mx.np.array(x), len(vocab)), mx.np.array(y)" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "source": [ |
| "split_idx = int(split*len(X))\n", |
| "train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)\n", |
| "test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "source": [ |
| "print(\"Input {}\".format(X[0]))\n", |
| "print(\"Transformed data Input {}\".format(train_dataset[0][0]))\n", |
| "print(\"Target {}\".format(Y[0]))\n", |
| "print(\"Transformed data Target {}\".format(train_dataset[0][1]))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Input [548 592 714 843 602]\n", |
| "Transformed data Input \n", |
| "[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", |
| " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", |
| " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", |
| " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", |
| " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", |
| " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", |
| " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", |
| " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]\n", |
| "<NDArray 19x11 @cpu(0)>\n", |
| "Target [548 592 602 714 843]\n", |
| "Transformed data Target \n", |
| "[ 5. 4. 8. 10. 5. 9. 2. 10. 6. 0. 2. 10. 7. 1. 4. 10. 8. 4.\n", |
| " 3.]\n", |
| "<NDArray 19 @cpu(0)>\n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "source": [ |
| "train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')\n", |
| "test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Creating the network" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "source": [ |
| "net = gluon.nn.HybridSequential()\n", |
| "net.add(\n", |
| "    gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),\n", |
| "    gluon.nn.Dense(len(vocab), flatten=False)\n", |
| ")" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "source": [ |
| "net.initialize(mx.init.Xavier(), ctx=ctx)" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "source": [ |
| "loss = gluon.loss.SoftmaxCELoss()" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "We use a learning rate schedule that multiplies the learning rate by 0.75 every 10 epochs to improve the convergence of the model." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "source": [ |
| "schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)\n", |
| "schedule.base_lr = 0.01" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 15, |
| "source": [ |
| "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Training loop" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 16, |
| "source": [ |
| "epochs = 100\n", |
| "for e in range(epochs):\n", |
| "    epoch_loss = 0.\n", |
| "    for i, (data, label) in enumerate(train_data):\n", |
| "        data = data.as_in_context(ctx)\n", |
| "        label = label.as_in_context(ctx)\n", |
| "\n", |
| "        with mx.autograd.record():\n", |
| "            output = net(data)\n", |
| "            l = loss(output, label)\n", |
| "\n", |
| "        l.backward()\n", |
| "        trainer.step(data.shape[0])\n", |
| "\n", |
| "        epoch_loss += l.mean()\n", |
| "\n", |
| "    print(\"Epoch [{}] Loss: {}, LR {}\".format(e, epoch_loss.item()/(i+1), trainer.learning_rate))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Epoch [0] Loss: 1.6627886372227823, LR 0.01\n", |
| "Epoch [1] Loss: 1.210370733382854, LR 0.01\n", |
| "Epoch [2] Loss: 0.9692377131035987, LR 0.01\n", |
| "Epoch [3] Loss: 0.7976046623067653, LR 0.01\n", |
| "Epoch [4] Loss: 0.5714595343476983, LR 0.01\n", |
| "Epoch [5] Loss: 0.4458411196444897, LR 0.01\n", |
| "Epoch [6] Loss: 0.36039798817736035, LR 0.01\n", |
| "Epoch [7] Loss: 0.32665719377233626, LR 0.01\n", |
| "Epoch [8] Loss: 0.262064205702915, LR 0.01\n", |
| "Epoch [9] Loss: 0.22285924059279422, LR 0.0075\n", |
| "Epoch [10] Loss: 0.19018426854559717, LR 0.0075\n", |
| "Epoch [11] Loss: 0.1718730723604243, LR 0.0075\n", |
| "Epoch [12] Loss: 0.15736752171670237, LR 0.0075\n", |
| "Epoch [13] Loss: 0.14579375246737866, LR 0.0075\n", |
| "Epoch [14] Loss: 0.13546599733068587, LR 0.0075\n", |
| "Epoch [15] Loss: 0.12490207590955368, LR 0.0075\n", |
| "Epoch [16] Loss: 0.11803316300915133, LR 0.0075\n", |
| "Epoch [17] Loss: 0.10653189395336395, LR 0.0075\n", |
| "Epoch [18] Loss: 0.10514750379197141, LR 0.0075\n", |
| "Epoch [19] Loss: 0.09590611559279422, LR 0.005625\n", |
| "Epoch [20] Loss: 0.08146028108494256, LR 0.005625\n", |
| "Epoch [21] Loss: 0.07707348782965477, LR 0.005625\n", |
| "Epoch [22] Loss: 0.07206193436967566, LR 0.005625\n", |
| "Epoch [23] Loss: 0.07001185417175293, LR 0.005625\n", |
| "Epoch [24] Loss: 0.06797058351578252, LR 0.005625\n", |
| "Epoch [25] Loss: 0.0649358110224947, LR 0.005625\n", |
| "Epoch [26] Loss: 0.06219124286732775, LR 0.005625\n", |
| "Epoch [27] Loss: 0.06075144828634059, LR 0.005625\n", |
| "Epoch [28] Loss: 0.05711334495134251, LR 0.005625\n", |
| "Epoch [29] Loss: 0.054747099572039666, LR 0.00421875\n", |
| "Epoch [30] Loss: 0.0441775271233092, LR 0.00421875\n", |
| "Epoch [31] Loss: 0.041551097910454936, LR 0.00421875\n", |
| "Epoch [32] Loss: 0.04095017269093503, LR 0.00421875\n", |
| "Epoch [33] Loss: 0.04045371045457556, LR 0.00421875\n", |
| "Epoch [34] Loss: 0.038867686657195394, LR 0.00421875\n", |
| "Epoch [35] Loss: 0.038131744303601854, LR 0.00421875\n", |
| "Epoch [36] Loss: 0.039834817250569664, LR 0.00421875\n", |
| "Epoch [37] Loss: 0.03669035941996473, LR 0.00421875\n", |
| "Epoch [38] Loss: 0.03373505967728635, LR 0.00421875\n", |
| "Epoch [39] Loss: 0.03164981273894615, LR 0.0031640625\n", |
| "Epoch [40] Loss: 0.025532766055035336, LR 0.0031640625\n", |
| "Epoch [41] Loss: 0.022659448867148543, LR 0.0031640625\n", |
| "Epoch [42] Loss: 0.02307056112492338, LR 0.0031640625\n", |
| "Epoch [43] Loss: 0.02236944056571798, LR 0.0031640625\n", |
| "Epoch [44] Loss: 0.022204211963120328, LR 0.0031640625\n", |
| "Epoch [45] Loss: 0.02262336903430046, LR 0.0031640625\n", |
| "Epoch [46] Loss: 0.02253308448385685, LR 0.0031640625\n", |
| "Epoch [47] Loss: 0.025286573044797207, LR 0.0031640625\n", |
| "Epoch [48] Loss: 0.02439300988310127, LR 0.0031640625\n", |
| "Epoch [49] Loss: 0.017976388018181983, LR 0.002373046875\n", |
| "Epoch [50] Loss: 0.014343131095805067, LR 0.002373046875\n", |
| "Epoch [51] Loss: 0.013039355582379281, LR 0.002373046875\n", |
| "Epoch [52] Loss: 0.011884741885687715, LR 0.002373046875\n", |
| "Epoch [53] Loss: 0.011438189668858305, LR 0.002373046875\n", |
| "Epoch [54] Loss: 0.011447292693117832, LR 0.002373046875\n", |
| "Epoch [55] Loss: 0.014212571560068334, LR 0.002373046875\n", |
| "Epoch [56] Loss: 0.019900493724371797, LR 0.002373046875\n", |
| "Epoch [57] Loss: 0.02102568301748722, LR 0.002373046875\n", |
| "Epoch [58] Loss: 0.01346214400961044, LR 0.002373046875\n", |
| "Epoch [59] Loss: 0.010107964911359422, LR 0.0017797851562500002\n", |
| "Epoch [60] Loss: 0.008353193600972494, LR 0.0017797851562500002\n", |
| "Epoch [61] Loss: 0.007678258292218472, LR 0.0017797851562500002\n", |
| "Epoch [62] Loss: 0.007262124660167288, LR 0.0017797851562500002\n", |
| "Epoch [63] Loss: 0.00705223578087827, LR 0.0017797851562500002\n", |
| "Epoch [64] Loss: 0.006788556293774677, LR 0.0017797851562500002\n", |
| "Epoch [65] Loss: 0.006473606571238091, LR 0.0017797851562500002\n", |
| "Epoch [66] Loss: 0.006206096486842378, LR 0.0017797851562500002\n", |
| "Epoch [67] Loss: 0.00584477313021396, LR 0.0017797851562500002\n", |
| "Epoch [68] Loss: 0.005648705267137097, LR 0.0017797851562500002\n", |
| "Epoch [69] Loss: 0.006481769871204458, LR 0.0013348388671875003\n", |
| "Epoch [70] Loss: 0.008430448618341, LR 0.0013348388671875003\n", |
| "Epoch [71] Loss: 0.006877245421105242, LR 0.0013348388671875003\n", |
| "Epoch [72] Loss: 0.005671108281740578, LR 0.0013348388671875003\n", |
| "Epoch [73] Loss: 0.004832422162624116, LR 0.0013348388671875003\n", |
| "Epoch [74] Loss: 0.004441103402604448, LR 0.0013348388671875003\n", |
| "Epoch [75] Loss: 0.004216198591475791, LR 0.0013348388671875003\n", |
| "Epoch [76] Loss: 0.004041922989711967, LR 0.0013348388671875003\n", |
| "Epoch [77] Loss: 0.003937713643337818, LR 0.0013348388671875003\n", |
| "Epoch [78] Loss: 0.010251983049068046, LR 0.0013348388671875003\n", |
| "Epoch [79] Loss: 0.01829354052848004, LR 0.0010011291503906252\n", |
| "Epoch [80] Loss: 0.006723233448561802, LR 0.0010011291503906252\n", |
| "Epoch [81] Loss: 0.004397524798170049, LR 0.0010011291503906252\n", |
| "Epoch [82] Loss: 0.0038475305476087206, LR 0.0010011291503906252\n", |
| "Epoch [83] Loss: 0.003591177945441388, LR 0.0010011291503906252\n", |
| "Epoch [84] Loss: 0.003425112014175743, LR 0.0010011291503906252\n", |
| "Epoch [85] Loss: 0.0032633850549129728, LR 0.0010011291503906252\n", |
| "Epoch [86] Loss: 0.0031762316505959693, LR 0.0010011291503906252\n", |
| "Epoch [87] Loss: 0.0030452777096565734, LR 0.0010011291503906252\n", |
| "Epoch [88] Loss: 0.002950224184220837, LR 0.0010011291503906252\n", |
| "Epoch [89] Loss: 0.002821172171450676, LR 0.0007508468627929689\n", |
| "Epoch [90] Loss: 0.002725780961361337, LR 0.0007508468627929689\n", |
| "Epoch [91] Loss: 0.002660556359493986, LR 0.0007508468627929689\n", |
| "Epoch [92] Loss: 0.0026011724946319414, LR 0.0007508468627929689\n", |
| "Epoch [93] Loss: 0.0025355776256703317, LR 0.0007508468627929689\n", |
| "Epoch [94] Loss: 0.0024825221997626283, LR 0.0007508468627929689\n", |
| "Epoch [95] Loss: 0.0024245587435174497, LR 0.0007508468627929689\n", |
| "Epoch [96] Loss: 0.002365282145879602, LR 0.0007508468627929689\n", |
| "Epoch [97] Loss: 0.0023112583984719946, LR 0.0007508468627929689\n", |
| "Epoch [98] Loss: 0.002257173682780976, LR 0.0007508468627929689\n", |
| "Epoch [99] Loss: 0.002162747085094452, LR 0.0005631351470947267\n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Testing" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "We pick a random element from the test set." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "source": [ |
| "# Index into the test samples; len(test_data) would only count batches\n", |
| "n = random.randint(0, len(test_dataset)-1)\n", |
| "\n", |
| "x_orig = X[split_idx+n]\n", |
| "y_orig = Y[split_idx+n]" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 41, |
| "source": [ |
| "def get_pred(x):\n", |
| "    x, _ = transform(x, x)\n", |
| "    output = net(mx.np.expand_dims(x.to_device(ctx), axis=0))\n", |
| "\n", |
| "    # Convert output back to string\n", |
| "    pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])\n", |
| "    return pred" |
| ], |
| "outputs": [], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "We print the input, the prediction and the expected label:" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 43, |
| "source": [ |
| "x_ = ' '.join(map(str,x_orig))\n", |
| "label = ' '.join(map(str,y_orig))\n", |
| "print(\"X {}\\nPredicted {}\\nLabel {}\".format(x_, get_pred(x_orig), label))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "X 611 671 275 871 944\n", |
| "Predicted 275 611 671 871 944\n", |
| "Label 275 611 671 871 944\n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "We can also pick our own example, and the network sorts it without any problem:" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 66, |
| "source": [ |
| "print(get_pred(onp.array([500, 30, 999, 10, 130])))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "10 30 130 500 999 \n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "The model has even learned to generalize to inputs unlike those in the training set, such as a sequence of only four numbers:" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 64, |
| "source": [ |
| "print(\"Only four numbers:\", get_pred(onp.array([105, 302, 501, 202])))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Only four numbers: 105 202 302 501 \n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "However, it has trouble with other edge cases, such as one- and two-digit numbers or sequences of more than **seq_len** numbers:" |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 63, |
| "source": [ |
| "print(\"Small digits:\", get_pred(onp.array([10, 3, 5, 2, 8])))\n", |
| "print(\"Small digits, 6 numbers:\", get_pred(onp.array([10, 33, 52, 21, 82, 10])))" |
| ], |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Small digits: 8 0 42 28 \n", |
| "Small digits, 6 numbers: 10 0 20 82 71 115 \n" |
| ] |
| } |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "This could probably be improved by adjusting the training dataset accordingly, for example by including more short numbers and sequences of varying length." |
| ], |
| "metadata": {} |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.6.4" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 2 |
| } |