| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "92af9af0", |
| "metadata": {}, |
| "source": [ |
| "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n", |
| "<!--- or more contributor license agreements. See the NOTICE file -->\n", |
| "<!--- distributed with this work for additional information -->\n", |
| "<!--- regarding copyright ownership. The ASF licenses this file -->\n", |
| "<!--- to you under the Apache License, Version 2.0 (the -->\n", |
| "<!--- \"License\"); you may not use this file except in compliance -->\n", |
| "<!--- with the License. You may obtain a copy of the License at -->\n", |
| "\n", |
| "<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n", |
| "\n", |
| "<!--- Unless required by applicable law or agreed to in writing, -->\n", |
| "<!--- software distributed under the License is distributed on an -->\n", |
| "<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n", |
| "<!--- KIND, either express or implied. See the License for the -->\n", |
| "<!--- specific language governing permissions and limitations -->\n", |
| "<!--- under the License. -->\n", |
| "\n", |
| "# Step 4: Necessary components that are not in the network\n", |
| "\n", |
| "Data and models are not the only components that\n", |
| "you need to train a deep learning model. In this notebook, you will\n", |
| "learn about the common components involved in training deep learning models. \n", |
| "Here is a list of components necessary for training models in MXNet.\n", |
| "\n", |
| "1. Initialization\n", |
| "2. Loss functions\n", |
| " 1. Built-in\n", |
| " 2. Custom\n", |
| "3. Optimizers\n", |
| "4. Metrics" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "id": "488d67ce", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
"from mxnet import np, npx, gluon\n",
| "import mxnet as mx\n", |
| "from mxnet.gluon import nn\n", |
| "npx.set_np()\n", |
| "\n", |
| "device = mx.cpu()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "2b79dfff", |
| "metadata": {}, |
| "source": [ |
| "## Initialization\n", |
| "\n", |
| "In a previous notebook, you used `net.initialize()` to initialize the network\n", |
| "before a forward pass. Now, you will learn about initialization in a little more\n", |
| "detail.\n", |
| "\n", |
"First, define and initialize the `Sequential` network from earlier.\n",
"After you initialize it, print the parameters using the `collect_params()` method."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "814e2583", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "Sequential(\n", |
| " (0): Dense(3 -> 5, Activation(relu))\n", |
| " (1): Dense(-1 -> 25, Activation(relu))\n", |
| " (2): Dense(-1 -> 2, linear)\n", |
| ")" |
| ] |
| }, |
| "execution_count": 2, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "net = nn.Sequential()\n", |
| "\n", |
| "net.add(nn.Dense(5, in_units=3, activation=\"relu\"),\n", |
| " nn.Dense(25, activation=\"relu\"),\n", |
| " nn.Dense(2)\n", |
| " )\n", |
| "\n", |
| "net" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "e886dfff", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "0.weight Parameter (shape=(5, 3), dtype=float32)\n", |
| "0.bias Parameter (shape=(5,), dtype=float32)\n", |
| "1.weight Parameter (shape=(25, -1), dtype=float32)\n", |
| "1.bias Parameter (shape=(25,), dtype=float32)\n", |
| "2.weight Parameter (shape=(2, -1), dtype=float32)\n", |
| "2.bias Parameter (shape=(2,), dtype=float32)\n" |
| ] |
| }, |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "[03:52:10] /work/mxnet/src/storage/storage.cc:202: Using Pooled (Naive) StorageManager for CPU\n" |
| ] |
| } |
| ], |
| "source": [ |
| "net.initialize()\n", |
| "params = net.collect_params()\n", |
| "\n", |
| "for key, value in params.items():\n", |
| " print(key, value)\n", |
| "\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "63a429c2", |
| "metadata": {}, |
| "source": [ |
"Next, print the parameters again after the first forward pass. The deferred input dimensions (shown as `-1` above) are now inferred from the data."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "8bcd843b", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "0.weight Parameter (shape=(5, 3), dtype=float32)\n", |
| "0.bias Parameter (shape=(5,), dtype=float32)\n", |
| "1.weight Parameter (shape=(25, 5), dtype=float32)\n", |
| "1.bias Parameter (shape=(25,), dtype=float32)\n", |
| "2.weight Parameter (shape=(2, 25), dtype=float32)\n", |
| "2.bias Parameter (shape=(2,), dtype=float32)\n" |
| ] |
| } |
| ], |
| "source": [ |
| "x = np.random.uniform(-1, 1, (10, 3))\n", |
| "net(x) # Forward computation\n", |
| "\n", |
| "params = net.collect_params()\n", |
| "for key, value in params.items():\n", |
| " print(key, value)\n", |
| "\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3e6611ec", |
| "metadata": {}, |
| "source": [ |
| "#### Built-in Initialization\n", |
| "\n", |
"MXNet makes it easy to initialize a network by providing many common initializers. A subset that you will use in the following sections includes:\n",
| "\n", |
| "- Constant\n", |
| "- Normal\n", |
| "\n", |
| "For more information, see\n", |
| "[Initializers](../../../api/initializer/index.rst)\n", |
| "\n", |
"By default, `net.initialize()` initializes the weight matrices by drawing\n",
"random values from a uniform distribution between -0.07 and 0.07, and sets\n",
"all the bias parameters to 0.\n",
| "\n", |
"To initialize your network with a different built-in initializer, pass it via the\n",
"`init` keyword argument of the `initialize()` method. Here is an example using\n",
"`Constant` and `Normal` initialization."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "12ed3c1a", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "[ 0.0110525 -0.01441184 -0.05202791]\n" |
| ] |
| }, |
| { |
| "name": "stderr", |
| "output_type": "stream", |
| "text": [ |
| "/work/mxnet/python/mxnet/util.py:755: UserWarning: Parameter 'weight' is already initialized, ignoring. Set force_reinit=True to re-initialize.\n", |
| " return func(*args, **kwargs)\n", |
| "/work/mxnet/python/mxnet/util.py:755: UserWarning: Parameter 'bias' is already initialized, ignoring. Set force_reinit=True to re-initialize.\n", |
| " return func(*args, **kwargs)\n" |
| ] |
| } |
| ], |
| "source": [ |
| "from mxnet import init\n", |
| "\n", |
"# Constant initializes all the weights to the same constant value. Because the\n",
"# network is already initialized, this call emits a warning and is ignored\n",
"# unless you pass force_reinit=True, so the printed weights are unchanged\n",
"net.initialize(init=init.Constant(3), device=device)\n",
| "print(net[0].weight.data()[0])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "997c332d", |
| "metadata": {}, |
| "source": [ |
"`Normal` initializes the weights by drawing from a normal distribution with a\n",
"mean of zero and a standard deviation of `sigma`. If you have\n",
"already initialized the weights but want to reinitialize them, set the\n",
"`force_reinit` flag to `True`."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "07e00fb6", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "[0.04093673 0.04288622 0.44828448]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "net.initialize(init=init.Normal(sigma=0.2), force_reinit=True, device=device)\n", |
| "print(net[0].weight.data()[0])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "839cb2fb", |
| "metadata": {}, |
| "source": [ |
| "## Components used in a training loop\n", |
| "\n", |
"So far, you have seen how to define a network and how to initialize it using the\n",
"MXNet APIs. But when you start training, how do you actually teach the\n",
"model to learn from the data?\n",
| "\n", |
| "There are three main components for training an algorithm.\n", |
| "\n", |
| "1. Loss function: calculates how far the model is from the true distribution\n", |
"2. Autograd: the MXNet automatic differentiation tool that calculates the gradients used to\n",
"optimize the parameters\n",
| "3. Optimizer: updates the parameters based on an optimization algorithm\n", |
| "\n", |
| "You have already learned about autograd in the previous notebook. In this\n", |
| "notebook, you will learn more about loss functions and optimizers.\n", |
| "\n", |
| "## Loss function\n", |
| "\n", |
"Loss functions are used to train neural networks and help the model learn\n",
"from the data. A loss function computes the difference between the\n",
"output of the neural network and the ground truth. This loss value is used to\n",
"update the neural network weights during training. Next, you will look at a\n",
"simple example.\n",
| "\n", |
| "Suppose you have a neural network `net` and the data is stored in a variable\n", |
| "`data`. The data consists of 5 total records (rows) and two features (columns)\n", |
| "and the output from the neural network after the first epoch is given by the\n", |
| "variable `nn_output`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "498ef8c0", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "array([[-0.0928179 ],\n", |
| " [-0.21960056],\n", |
| " [-0.07770944],\n", |
| " [-0.10332978],\n", |
| " [-0.07763617]])" |
| ] |
| }, |
| "execution_count": 7, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "net = gluon.nn.Dense(1)\n", |
| "net.initialize()\n", |
| "\n", |
| "nn_input = np.array([[1.2, 0.56],\n", |
| " [3.0, 0.72],\n", |
| " [0.89, 0.9],\n", |
| " [0.89, 2.3],\n", |
| " [0.99, 0.52]])\n", |
| "\n", |
| "nn_output = net(nn_input)\n", |
| "nn_output" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "3a803c95", |
| "metadata": {}, |
| "source": [ |
"The ground truth values of the data are stored in `groundtruth_label`:"
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "5ac98831", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "groundtruth_label = np.array([[0.0083],\n", |
| " [0.00382],\n", |
| " [0.02061],\n", |
| " [0.00495],\n", |
| " [0.00639]]).reshape(5, 1)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "dca140dc", |
| "metadata": {}, |
| "source": [ |
| "For this problem, you will use the L2 Loss. L2Loss, also called Mean Squared Error, is a\n", |
| "regression loss function that computes the squared distances between the target\n", |
| "values and the output of the neural network. It is defined as:\n", |
| "\n", |
"$$L = \\frac{1}{2N}\\sum_i{(label_i - pred_i)^2}$$\n",
| "\n", |
"Due to the square operator, the L2 loss creates larger gradients for predictions that are\n",
"farther from the target, and it also smooths the loss surface."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "571d764e", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "array([0.00511241, 0.02495837, 0.00483336, 0.00586226, 0.0035302 ])" |
| ] |
| }, |
| "execution_count": 9, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "def L2Loss(output_values, true_values):\n", |
| " return np.mean((output_values - true_values) ** 2, axis=1) / 2\n", |
| "\n", |
| "L2Loss(nn_output, groundtruth_label)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "5734797f", |
| "metadata": {}, |
| "source": [ |
"Now, you can compute the same loss using the MXNet API:"
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "962d8446", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "array([0.00511241, 0.02495837, 0.00483336, 0.00586226, 0.0035302 ])" |
| ] |
| }, |
| "execution_count": 10, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
"from mxnet.gluon import loss as gloss\n",
| "loss = gloss.L2Loss()\n", |
| "\n", |
| "loss(nn_output, groundtruth_label)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "bbd1d98a", |
| "metadata": {}, |
| "source": [ |
| "A network can improve by iteratively updating its weights to minimise the loss.\n", |
| "Some tasks use a combination of multiple loss functions, but often you will just\n", |
| "use one. MXNet Gluon provides a number of the most commonly used loss functions.\n", |
| "The choice of your loss function will depend on your network and task. Some\n", |
| "common tasks and loss function pairs include:\n", |
| "\n", |
| "- regression: L1Loss, L2Loss\n", |
| "\n", |
| "- classification: SigmoidBinaryCrossEntropyLoss, SoftmaxCrossEntropyLoss\n", |
| "\n", |
| "- embeddings: HingeLoss\n", |
| "\n", |
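"For instance, a classification loss is called the same way as `L2Loss`. A\n",
"minimal sketch, assuming two hypothetical samples of raw class scores over\n",
"three classes and integer labels:\n",
"\n",
"```python\n",
"# Hypothetical class scores (logits) for two samples over three classes\n",
"scores = np.array([[2.0, 0.5, 0.1],\n",
"                   [0.1, 0.2, 3.0]])\n",
"class_labels = np.array([0, 2])\n",
"\n",
"softmax_ce = gloss.SoftmaxCrossEntropyLoss()\n",
"softmax_ce(scores, class_labels)  # one loss value per sample\n",
"```\n",
"\n",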
| "#### Customizing your Loss functions\n", |
| "\n", |
| "You can also create custom loss functions using **Loss Blocks**.\n", |
| "\n", |
| "You can inherit the base `Loss` class and write your own `forward` method. The\n", |
| "backward propagation will be automatically computed by autograd. However, that\n", |
"only holds true if you can build your loss from existing MXNet operators."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "id": "2e6d6b64", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "array([0.10111789, 0.22342056, 0.09831944, 0.10827978, 0.08402617])" |
| ] |
| }, |
| "execution_count": 11, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "from mxnet.gluon.loss import Loss\n", |
| "\n", |
| "class custom_L1_loss(Loss):\n", |
| " def __init__(self, weight=None, batch_axis=0, **kwargs):\n", |
| " super(custom_L1_loss, self).__init__(weight, batch_axis, **kwargs)\n", |
| "\n", |
| " def forward(self, pred, label):\n", |
| " l = np.abs(label - pred)\n", |
| " l = l.reshape(len(l),)\n", |
| " return l\n", |
| " \n", |
| "L1 = custom_L1_loss()\n", |
| "L1(nn_output, groundtruth_label)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "id": "ad266bc3", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "array([0.10111789, 0.22342056, 0.09831944, 0.10827978, 0.08402617])" |
| ] |
| }, |
| "execution_count": 12, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
"l1 = gloss.L1Loss()\n",
| "l1(nn_output, groundtruth_label)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "c04e72d4", |
| "metadata": {}, |
| "source": [ |
| "## Optimizer\n", |
| "\n", |
"The loss function measures how far the model is from the ground truth. The\n",
"optimizer determines how the model weights or parameters are updated based on\n",
"that loss. In Gluon, this optimization step is performed by the\n",
"`gluon.Trainer`.\n",
"\n",
"Here is a basic example of how to create a `gluon.Trainer`."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "id": "3d10935b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from mxnet import optimizer" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "id": "2f57760d", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "trainer = gluon.Trainer(net.collect_params(),\n", |
| " optimizer=\"Adam\",\n", |
| " optimizer_params={\n", |
| " \"learning_rate\":0.1,\n", |
| " \"wd\":0.001\n", |
| " })" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "eaec1160", |
| "metadata": {}, |
| "source": [ |
"When creating a **Gluon Trainer**, you must provide the trainer object with\n",
"1. A collection of parameters that need to be learned. This collection will be\n",
"the weights and biases of the network that you are training.\n",
"2. An optimization algorithm (optimizer) to use for training. This\n",
"algorithm is used to update the parameters every training iteration when\n",
"`trainer.step` is called. For more information, see\n",
"[optimizers](../../../api/optimizer/index.rst)"
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 15, |
| "id": "cd3e10ff", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "[[-0.06880813 -0.01830024]]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "curr_weight = net.weight.data()\n", |
| "print(curr_weight)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 16, |
| "id": "dc2ccb6b", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "[[-0.06880813 -0.01830024]]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "batch_size = len(nn_input)\n", |
| "trainer.step(batch_size, ignore_stale_grad=True)\n", |
| "print(net.weight.data())" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "id": "c588c6f9", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "[[-0.06880813 -0.01830024]]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(curr_weight - net.weight.grad() * 1 / 5)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a509a281", |
| "metadata": {}, |
| "source": [ |
| "## Metrics\n", |
| "\n", |
"MXNet includes a `gluon.metric` API that you can use to evaluate how your model is\n",
"performing. It is typically used during training to monitor performance on the\n",
"validation set. MXNet includes many commonly used metrics; a few are listed below:\n",
| "\n", |
| "- [Accuracy](../../../api/gluon/metric/index.rst#mxnet.gluon.metric.Accuracy)\n", |
| "- [CrossEntropy](../../../api/gluon/metric/index.rst#mxnet.gluon.metric.CrossEntropy)\n", |
| "- [Mean squared error](../../../api/gluon/metric/index.rst#mxnet.gluon.metric.MSE)\n", |
| "- [Root mean squared error (RMSE)](../../../api/gluon/metric/index.rst#mxnet.gluon.metric.RMSE)\n", |
| "\n", |
| "Now, you will define two arrays for a dummy binary classification example." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 18, |
| "id": "a4626cf5", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Vector of likelihoods for all the classes\n", |
| "pred = np.array([[0.1, 0.9], [0.05, 0.95], [0.83, 0.17], [0.63, 0.37]])\n", |
| "\n", |
| "labels = np.array([1, 1, 0, 1])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "7403f1a6", |
| "metadata": {}, |
| "source": [ |
"To calculate the accuracy of your model, first instantiate the metric,\n",
"typically before the training loop:"
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 19, |
| "id": "37e7c5c8", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from mxnet.gluon.metric import Accuracy\n", |
| "\n", |
| "acc = Accuracy()" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "b65ad675", |
| "metadata": {}, |
| "source": [ |
"To calculate the updated accuracy for each batch or epoch, call\n",
"the `update()` method. It takes labels and predictions, which can be\n",
"either class indices or vectors of likelihoods over all of the classes."
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 20, |
| "id": "5fcd316b", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "acc.update(labels=labels, preds=pred)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "6a7963a8", |
| "metadata": {}, |
| "source": [ |
| "#### Creating custom metrics\n", |
| "\n", |
| "In addition to built-in metrics, if you want to create a custom metric, you can\n", |
| "use the following skeleton code. This code inherits from the `EvalMetric` base\n", |
| "class." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 21, |
| "id": "84517264", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
"from mxnet.gluon.metric import EvalMetric\n",
"\n",
"class MyCustomMetric(EvalMetric):\n",
"    def __init__(self):\n",
"        # EvalMetric requires a name for the metric\n",
"        super().__init__(name=\"my_custom_metric\")\n",
"\n",
"    def update(self, labels, preds):\n",
"        pass\n"
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "a9b9eee7", |
| "metadata": {}, |
| "source": [ |
| "Here is an example using the Precision metric. First, define the two values\n", |
| "`labels` and `preds`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 22, |
| "id": "56c97528", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "labels = np.array([0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1])\n", |
| "preds = np.array([0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0])" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "d433b134", |
| "metadata": {}, |
| "source": [ |
| "Next, define the custom metric class `precision` and instantiate it" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 23, |
| "id": "6e544142", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "from mxnet.gluon.metric import EvalMetric\n", |
| "\n", |
"class precision(EvalMetric):\n",
"    def __init__(self):\n",
"        super().__init__(name=\"Precision\")\n",
"\n",
"    def update(self, labels, preds):\n",
"        tp_labels = (labels == 1)\n",
"        true_positives = sum(preds[tp_labels] == 1)\n",
"        fp_labels = (labels == 0)\n",
"        false_positives = sum(preds[fp_labels] == 1)\n",
"        return true_positives / (true_positives + false_positives)\n",
"\n",
"p = precision()"
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "bb38ae9c", |
| "metadata": {}, |
| "source": [ |
| "And finally, call the `update` method to return the results of `precision` for your data" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 24, |
| "id": "aa4a8253", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "data": { |
| "text/plain": [ |
| "array(0.4)" |
| ] |
| }, |
| "execution_count": 24, |
| "metadata": {}, |
| "output_type": "execute_result" |
| } |
| ], |
| "source": [ |
| "p.update(np.array(labels), np.array(preds))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "284830de", |
| "metadata": {}, |
| "source": [ |
| "## Next steps\n", |
| "\n", |
| "Now that you have learned all the components required to train a neural network,\n", |
| "you will see how to load your data using the Gluon API in [Step 5: Gluon\n", |
| "Datasets and DataLoader](./5-datasets.ipynb)" |
| ] |
| } |
| ], |
| "metadata": { |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |