| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "896ea900", |
| "metadata": {}, |
| "source": [ |
| "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n", |
| "<!--- or more contributor license agreements. See the NOTICE file -->\n", |
| "<!--- distributed with this work for additional information -->\n", |
| "<!--- regarding copyright ownership. The ASF licenses this file -->\n", |
| "<!--- to you under the Apache License, Version 2.0 (the -->\n", |
| "<!--- \"License\"); you may not use this file except in compliance -->\n", |
| "<!--- with the License. You may obtain a copy of the License at -->\n", |
| "\n", |
| "<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n", |
| "\n", |
| "<!--- Unless required by applicable law or agreed to in writing, -->\n", |
| "<!--- software distributed under the License is distributed on an -->\n", |
| "<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n", |
| "<!--- KIND, either express or implied. See the License for the -->\n", |
| "<!--- specific language governing permissions and limitations -->\n", |
| "<!--- under the License. -->\n", |
| "\n", |
| "# Step 6: Train a Neural Network\n", |
| "\n", |
| "Now that you have seen all the necessary components for creating a neural network, you are\n", |
| "now ready to put all the pieces together and train a model end to end.\n", |
| "\n", |
| "## 1. Data preparation\n", |
| "\n", |
| "The typical process for creating and training a model starts with loading and\n", |
| "preparing the datasets. For this Network you will use a [dataset of leaf\n", |
| "images](https://data.mendeley.com/datasets/hb74ynkjcn/1) that consists of healthy\n", |
| "and diseased examples of leafs from twelve different plant species. To get this\n", |
| "dataset you have to download and extract it with the following commands." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "1f0a533a", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Import all the necessary libraries to train\n", |
| "import time\n", |
| "import os\n", |
| "import zipfile\n", |
| "\n", |
| "import mxnet as mx\n", |
| "from mxnet import np, npx, gluon, init, autograd\n", |
| "from mxnet.gluon import nn\n", |
| "from mxnet.gluon.data.vision import transforms\n", |
| "\n", |
| "import matplotlib.pyplot as plt\n", |
| "import matplotlib.pyplot as plt\n", |
| "import numpy as np\n", |
| "\n", |
| "from prepare_dataset import process_dataset #utility code to rearrange the data\n", |
| "\n", |
| "mx.random.seed(42)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "3434543b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Download dataset\n", |
| "url = 'https://md-datasets-cache-zipfiles-prod.s3.eu-west-1.amazonaws.com/hb74ynkjcn-1.zip'\n", |
| "zip_file_path = mx.gluon.utils.download(url)\n", |
| "\n", |
| "os.makedirs('plants', exist_ok=True)\n", |
| "\n", |
| "with zipfile.ZipFile(zip_file_path, 'r') as zf:\n", |
| " zf.extractall('plants')\n", |
| "\n", |
| "os.remove(zip_file_path)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1fde2966", |
| "metadata": {}, |
| "source": [ |
| "#### Data inspection\n", |
| "\n", |
| "If you take a look at the dataset you find the following structure for the directories:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f074d4e1", |
| "metadata": {}, |
| "source": [ |
| "```\n", |
| "plants\n", |
| "|-- Alstonia Scholaris (P2)\n", |
| "|-- Arjun (P1)\n", |
| "|-- Bael (P4)\n", |
| " |-- diseased\n", |
| " |-- 0016_0001.JPG\n", |
| " |-- .\n", |
| " |-- .\n", |
| " |-- .\n", |
| " |-- 0016_0118.JPG\n", |
| "|-- .\n", |
| "|-- .\n", |
| "|-- .\n", |
| "|-- Mango (P0)\n", |
| " |-- diseased\n", |
| " |-- healthy\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b8364857", |
| "metadata": {}, |
| "source": [ |
| "Each plant species has its own directory, for each of those directories you might\n", |
| "find subdirectories with examples of diseased leaves, healthy\n", |
| "leaves, or both. With this dataset you can formulate different classification\n", |
| "problems; for example, you can create a multi-class classifier that determines\n", |
| "the species of a plant based on the leaves; you can instead create a binary\n", |
| "classifier that tells you whether the plant is healthy or diseased. Additionally, you can create\n", |
| "a multi-class, multi-label classifier that tells you both: what species a\n", |
| "plant is and whether the plant is diseased or healthy. In this example you will stick to\n", |
| "the simplest classification question, which is whether a plant is healthy or not.\n", |
| "\n", |
| "To do this, you need to manipulate the dataset in two ways. First, you need to\n", |
| "combine all images with labels consisting of healthy and diseased, regardless of the species, and then you\n", |
| "need to split the data into train, validation, and test sets. We prepared a\n", |
| "small utility script that does this to get the dataset ready for you.\n", |
| "Once you run this utility code on the data, the structure will be\n", |
| "already organized in folders containing the right images in each of the classes,\n", |
| "you can use the `ImageFolderDataset` class to import the images from the file to MXNet." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "8af6ab9d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Call the utility function to rearrange the images\n", |
| "process_dataset('plants')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "ec6eec3a", |
| "metadata": {}, |
| "source": [ |
| "The dataset is located in the `datasets` folder and the new structure\n", |
| "looks like this:" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f01413a8", |
| "metadata": {}, |
| "source": [ |
| "```\n", |
| "datasets\n", |
| "|-- test\n", |
| " |-- diseased\n", |
| " |-- healthy\n", |
| "|-- train\n", |
| "|-- validation\n", |
| " |-- diseased\n", |
| " |-- healthy\n", |
| " |-- image1.JPG\n", |
| " |-- image2.JPG\n", |
| " |-- .\n", |
| " |-- .\n", |
| " |-- .\n", |
| " |-- imagen.JPG\n", |
| "```\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a3abbe80", |
| "metadata": {}, |
| "source": [ |
| "Now, you need to create three different Dataset objects from the `train`,\n", |
| "`validation`, and `test` folders, and the `ImageFolderDataset` class takes\n", |
| "care of inferring the classes from the directory names. If you don't remember\n", |
| "how the `ImageFolderDataset` works, take a look at [Step 5](5-datasets.md)\n", |
| "of this course for a deeper description." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "24a982b8", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Use ImageFolderDataset to create a Dataset object from directory structure\n", |
| "train_dataset = gluon.data.vision.ImageFolderDataset('./datasets/train')\n", |
| "val_dataset = gluon.data.vision.ImageFolderDataset('./datasets/validation')\n", |
| "test_dataset = gluon.data.vision.ImageFolderDataset('./datasets/test')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f98f8bd8", |
| "metadata": {}, |
| "source": [ |
| "The result from this operation is a different Dataset object for each folder.\n", |
| "These objects hold a collection of images and labels and as such they can be\n", |
| "indexed, to get the $i$-th element from the dataset. The $i$-th element is a\n", |
| "tuple with two objects, the first object of the tuple is the image in array\n", |
| "form and the second is the corresponding label for that image." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "7ad108ed", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "sample_idx = 888 # choose a random sample\n", |
| "sample = train_dataset[sample_idx]\n", |
| "data = sample[0]\n", |
| "label = sample[1]\n", |
| "\n", |
| "plt.imshow(data.asnumpy())\n", |
| "print(f\"Data type: {data.dtype}\")\n", |
| "print(f\"Label: {label}\")\n", |
| "print(f\"Label description: {train_dataset.synsets[label]}\")\n", |
| "print(f\"Image shape: {data.shape}\")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "2c8df173", |
| "metadata": {}, |
| "source": [ |
| "As you can see from the plot, the image size is very large 4000 x 6000 pixels.\n", |
| "Usually, you downsize images before passing them to a neural network to reduce the training time.\n", |
| "It is also customary to make slight modifications to the images to improve generalization. That is why you add\n", |
| "transformations to the data in a process called Data Augmentation.\n", |
| "\n", |
| "You can augment data in MXNet using `transforms`. For a complete list of all\n", |
| "the available transformations in MXNet check out\n", |
| "[available transforms](../../../api/gluon/data/vision/transforms/index.rst).\n", |
| "It is very common to use more than one transform per image, and it is also\n", |
| "common to process transforms sequentially. To this end, you can use the `transforms.Compose` class.\n", |
| "This class is very useful to create a transformation pipeline for your images.\n", |
| "\n", |
| "You have to compose two different transformation pipelines, one for training\n", |
| "and the other one for validating and testing. This is because each pipeline\n", |
| "serves different pursposes. You need to downsize, convert to tensor and normalize\n", |
| "images across all the different datsets; however, you typically do not want to randomly flip\n", |
| "or add color jitter to the validation or test images since you could reduce performance." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "45a84b5c", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Import transforms as compose a series of transformations to the images\n", |
| "from mxnet.gluon.data.vision import transforms\n", |
| "\n", |
| "jitter_param = 0.05\n", |
| "\n", |
| "# mean and std for normalizing image value in range (0,1)\n", |
| "mean = [0.485, 0.456, 0.406]\n", |
| "std = [0.229, 0.224, 0.225]\n", |
| "\n", |
| "training_transformer = transforms.Compose([\n", |
| " transforms.Resize(size=224, keep_ratio=True),\n", |
| " transforms.CenterCrop(128),\n", |
| " transforms.RandomFlipLeftRight(),\n", |
| " transforms.RandomColorJitter(contrast=jitter_param),\n", |
| " transforms.ToTensor(),\n", |
| " transforms.Normalize(mean, std)\n", |
| "])\n", |
| "\n", |
| "validation_transformer = transforms.Compose([\n", |
| " transforms.Resize(size=224, keep_ratio=True),\n", |
| " transforms.CenterCrop(128),\n", |
| " transforms.ToTensor(),\n", |
| " transforms.Normalize(mean, std)\n", |
| "])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "88809d0b", |
| "metadata": {}, |
| "source": [ |
| "With your augmentations ready, you can create the `DataLoaders` to use them. To\n", |
| "do this the `gluon.data.DataLoader` class comes in handy. You have to pass the dataset with\n", |
| "the applied transformations (notice the `.transform_first()` method on the datasets)\n", |
| "to `gluon.data.DataLoader`. Additionally, you need to decide the batch size,\n", |
| "which is how many images you will be passing to the network,\n", |
| "and whether you want to shuffle the dataset." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "602bd977", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Create data loaders\n", |
| "batch_size = 4\n", |
| "train_loader = gluon.data.DataLoader(train_dataset.transform_first(training_transformer),\n", |
| " batch_size=batch_size,\n", |
| " shuffle=True,\n", |
| " try_nopython=True)\n", |
| "validation_loader = gluon.data.DataLoader(val_dataset.transform_first(validation_transformer),\n", |
| " batch_size=batch_size,\n", |
| " try_nopython=True)\n", |
| "test_loader = gluon.data.DataLoader(test_dataset.transform_first(validation_transformer),\n", |
| " batch_size=batch_size,\n", |
| " try_nopython=True)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "f6231c94", |
| "metadata": {}, |
| "source": [ |
| "Now, you can inspect the transformations that you made to the images. A prepared\n", |
| "utility function has been provided for this." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "21083577", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Function to plot batch\n", |
| "def show_batch(batch, columns=4, fig_size=(9, 5), pad=1):\n", |
| " labels = batch[1].asnumpy()\n", |
| " batch = batch[0] / 2 + 0.5 # unnormalize\n", |
| " batch = np.clip(batch.asnumpy(), 0, 1) # clip values\n", |
| " size = batch.shape[0]\n", |
| " rows = int(size / columns)\n", |
| " fig, axes = plt.subplots(rows, columns, figsize=fig_size)\n", |
| " for ax, img, label in zip(axes.flatten(), batch, labels):\n", |
| " ax.imshow(np.transpose(img, (1, 2, 0)))\n", |
| " ax.set(title=f\"Label: {label}\")\n", |
| " fig.tight_layout(h_pad=pad, w_pad=pad)\n", |
| " plt.show()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "81a83344", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "for batch in train_loader:\n", |
| " a = batch\n", |
| " break" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "f681d5e7", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "show_batch(a)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "dfa9fb4a", |
| "metadata": {}, |
| "source": [ |
| "You can see that the original images changed to have different sizes and variations\n", |
| "in color and lighting. These changes followed the specified transformations you stated\n", |
| "in the pipeline. You are now ready to go to the next step: **Create the\n", |
| "architecture**.\n", |
| "\n", |
| "## 2. Create Neural Network\n", |
| "\n", |
| "Convolutional neural networks are a great tool to capture the spatial\n", |
| "relationship of pixel values within images, for this reason they have become the\n", |
| "gold standard for computer vision. In this example you will create a small convolutional neural\n", |
| "network using what you learned from [Step 2](2-create-nn.md) of this crash course series.\n", |
| "First, you can set up two functions that will generate the two types of blocks\n", |
| "you intend to use, the convolution block and the dense block. Then you can create an\n", |
| "entire network based on these two blocks using a custom class." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "b4ff0029", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# The convolutional block has a convolution layer, a max pool layer and a batch normalization layer\n", |
| "def conv_block(filters, kernel_size=2, stride=2, batch_norm=True):\n", |
| " conv_block = nn.HybridSequential()\n", |
| " conv_block.add(nn.Conv2D(channels=filters, kernel_size=kernel_size, activation='relu'),\n", |
| " nn.MaxPool2D(pool_size=4, strides=stride))\n", |
| " if batch_norm:\n", |
| " conv_block.add(nn.BatchNorm())\n", |
| " return conv_block\n", |
| "\n", |
| "# The dense block consists of a dense layer and a dropout layer\n", |
| "def dense_block(neurons, activation='relu', dropout=0.2):\n", |
| " dense_block = nn.HybridSequential()\n", |
| " dense_block.add(nn.Dense(neurons, activation=activation))\n", |
| " if dropout:\n", |
| " dense_block.add(nn.Dropout(dropout))\n", |
| " return dense_block" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "15e92664", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Create neural network blueprint using the blocks\n", |
| "class LeafNetwork(nn.HybridBlock):\n", |
| " def __init__(self):\n", |
| " super(LeafNetwork, self).__init__()\n", |
| " self.conv1 = conv_block(32)\n", |
| " self.conv2 = conv_block(64)\n", |
| " self.conv3 = conv_block(128)\n", |
| " self.flatten = nn.Flatten()\n", |
| " self.dense1 = dense_block(100)\n", |
| " self.dense2 = dense_block(10)\n", |
| " self.dense3 = nn.Dense(2)\n", |
| "\n", |
| " def forward(self, batch):\n", |
| " batch = self.conv1(batch)\n", |
| " batch = self.conv2(batch)\n", |
| " batch = self.conv3(batch)\n", |
| " batch = self.flatten(batch)\n", |
| " batch = self.dense1(batch)\n", |
| " batch = self.dense2(batch)\n", |
| " batch = self.dense3(batch)\n", |
| "\n", |
| " return batch" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "edf5429f", |
| "metadata": {}, |
| "source": [ |
| "You have concluded the architecting part of the network, so now you can actually\n", |
| "build a model from that architecture for training. As you have seen\n", |
| "previously on [Step 4](4-components.md) of this\n", |
| "crash course series, to use the network you need to initialize the parameters and\n", |
| "hybridize the model." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "d374c0cb", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Create the model based on the blueprint provided and initialize the parameters\n", |
| "ctx = mx.cpu()\n", |
| "\n", |
| "initializer = mx.initializer.Xavier()\n", |
| "\n", |
| "model = LeafNetwork()\n", |
| "model.initialize(initializer, ctx=ctx)\n", |
| "model.summary(mx.nd.random.uniform(shape=(4, 3, 128, 128)))\n", |
| "model.hybridize()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "648ce7fd", |
| "metadata": {}, |
| "source": [ |
| "## 3. Choose Optimizer and Loss function\n", |
| "\n", |
| "With the network created you can move on to choosing an optimizer and a loss\n", |
| "function. The network you created uses these components to make an informed decision on how\n", |
| "to tune the parameters to fit the final objective better. You can use the `gluon.Trainer` class to\n", |
| "help with optimizing these parameters. The `gluon.Trainer` class needs two things to work\n", |
| "properly: the parameters needing to be tuned and the optimizer with its\n", |
| "corresponding hyperparameters. The trainer uses the error reported by the loss\n", |
| "function to optimize these parameters.\n", |
| "\n", |
| "For this particular dataset you will use Stochastic Gradient Descent as the\n", |
| "optimizer and Cross Entropy as the loss function." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "546e23f8", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# SGD optimizer\n", |
| "optimizer = 'sgd'\n", |
| "\n", |
| "# Set parameters\n", |
| "optimizer_params = {'learning_rate': 0.001}\n", |
| "\n", |
| "# Define the trainer for the model\n", |
| "trainer = gluon.Trainer(model.collect_params(), optimizer, optimizer_params)\n", |
| "\n", |
| "# Define the loss function\n", |
| "loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "5100d019", |
| "metadata": {}, |
| "source": [ |
| "Finally, you have to set up the training loop, and you need to create a function to evaluate the performance of the network on the validation dataset." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "fd9f6155", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Function to return the accuracy for the validation and test set\n", |
| "def test(val_data):\n", |
| " acc = gluon.metric.Accuracy()\n", |
| " for batch in val_data:\n", |
| " data = batch[0]\n", |
| " labels = batch[1]\n", |
| " outputs = model(data)\n", |
| " acc.update([labels], [outputs])\n", |
| "\n", |
| " _, accuracy = acc.get()\n", |
| " return accuracy" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "44be0b78", |
| "metadata": {}, |
| "source": [ |
| "## 4. Training Loop\n", |
| "\n", |
| "Now that you have everything set up, you can start training your network. This might\n", |
| "take some time to train depending on the hardware, number of layers, batch size and\n", |
| "images you use. For this particular case, you will only train for 2 epochs." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "79973b1b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Start the training loop\n", |
| "epochs = 2\n", |
| "accuracy = gluon.metric.Accuracy()\n", |
| "log_interval = 5\n", |
| "\n", |
| "for epoch in range(epochs):\n", |
| " tic = time.time()\n", |
| " btic = time.time()\n", |
| " accuracy.reset()\n", |
| "\n", |
| " for idx, batch in enumerate(train_loader):\n", |
| " data = batch[0]\n", |
| " label = batch[1]\n", |
| " with mx.autograd.record():\n", |
| " outputs = model(data)\n", |
| " loss = loss_fn(outputs, label)\n", |
| " mx.autograd.backward(loss)\n", |
| " trainer.step(batch_size)\n", |
| " accuracy.update([label], [outputs])\n", |
| " if log_interval and (idx + 1) % log_interval == 0:\n", |
| " _, acc = accuracy.get()\n", |
| "\n", |
| " print(f\"\"\"Epoch[{epoch + 1}] Batch[{idx + 1}] Speed: {batch_size / (time.time() - btic)} samples/sec \\\n", |
| " batch loss = {loss.mean().asscalar()} | accuracy = {acc}\"\"\")\n", |
| " btic = time.time()\n", |
| "\n", |
| " _, acc = accuracy.get()\n", |
| "\n", |
| " acc_val = test(validation_loader)\n", |
| " print(f\"[Epoch {epoch + 1}] training: accuracy={acc}\")\n", |
| " print(f\"[Epoch {epoch + 1}] time cost: {time.time() - tic}\")\n", |
| " print(f\"[Epoch {epoch + 1}] validation: validation accuracy={acc_val}\")" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "473cd2e0", |
| "metadata": {}, |
| "source": [ |
| "## 5. Test on the test set\n", |
| "\n", |
| "Now that your network is trained and has reached a decent accuracy, you can\n", |
| "evaluate the performance on the test set. For that, you can use the `test_loader` data\n", |
| "loader and the test function you created previously." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "25f45637", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "test(test_loader)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "cbed9eaf", |
| "metadata": {}, |
| "source": [ |
| "You have a trained network that can confidently discriminate between plants that\n", |
| "are healthy and the ones that are diseased. You can now start your garden and\n", |
| "set cameras to automatically detect plants in distress! Or change your classification\n", |
| "problem to create a model that classify the species of the plants! Either way you\n", |
| "might be able to impress your botanist friends.\n", |
| "\n", |
| "## 6. Save the parameters\n", |
| "\n", |
| "If you want to preserve the trained weights of the network you can save the\n", |
| "parameters in a file. Later, when you want to use the network to make predictions\n", |
| "you can load the parameters back!" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "0b9659cc", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Save parameters in the\n", |
| "model.save_parameters('leaf_models.params')" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "1392cce3", |
| "metadata": {}, |
| "source": [ |
| "This is the end of this tutorial, to see how you can speed up the training by\n", |
| "using GPU hardware continue to the [next tutorial](./7-use-gpus.ipynb)" |
| ] |
| } |
| ], |
| "metadata": { |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |