{
"cells": [
{
"cell_type": "markdown",
"id": "130b443b",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Loss functions\n",
"\n",
"Loss functions are used to train neural networks and to compute the difference between output and target variable. A critical component of training neural networks is the loss function. A loss function is a quantative measure of how bad the predictions of the network are when compared to ground truth labels. Given this score, a network can improve by iteratively updating its weights to minimise this loss. Some tasks use a combination of multiple loss functions, but often you'll just use one. MXNet Gluon provides a number of the most commonly used loss functions, and you'll choose certain loss functions depending on your network and task. Some common task and loss function pairs include:\n",
"\n",
"- Regression: [L1Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.L1Loss), [L2Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.L2Loss)\n",
"- Classification: [SigmoidBinaryCrossEntropyLoss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.SigmoidBinaryCrossEntropyLoss), [SoftmaxCrossEntropyLoss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.SoftmaxCrossEntropyLoss)\n",
"- Embeddings: [HingeLoss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.HingeLoss)\n",
"\n",
"We'll first import the modules, where the `mxnet.gluon.loss` module is imported as `gloss` to avoid the commonly used name `loss`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c22dda1a",
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"from matplotlib import pyplot as plt\n",
"import mxnet as mx\n",
"from mxnet import np, npx, autograd\n",
"from mxnet.gluon import nn, loss as gloss"
]
},
{
"cell_type": "markdown",
"id": "093c16e4",
"metadata": {},
"source": [
"## Basic Usages\n",
"\n",
"Now let's create an instance of the $\\ell_2$ loss, which is commonly used in regression tasks."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5573d646",
"metadata": {},
"outputs": [],
"source": [
"loss = gloss.L2Loss()"
]
},
{
"cell_type": "markdown",
"id": "6d014d78",
"metadata": {},
"source": [
"And then feed two inputs to compute the elementwise loss values."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "781a6be0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.5, 0.5])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.ones((2,))\n",
"y = np.ones((2,)) * 2\n",
"loss(x, y)"
]
},
{
"cell_type": "markdown",
"id": "35e2189a",
"metadata": {},
"source": [
"These values should be equal to the math definition: $0.5\\|x-y\\|^2$."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e5ad0274",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.5, 0.5])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
".5 * (x - y)**2"
]
},
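{
"cell_type": "markdown",
"id": "sketch-l2-reference",
"metadata": {},
"source": [
"As a cross-check, the same numbers can be reproduced without MXNet. The block below is a minimal sketch in plain NumPy, not Gluon's implementation: `l2_loss_reference` is a hypothetical helper, `onp` is an alias chosen here so that it does not shadow MXNet's `np`, and it assumes that dimensions other than the batch axis (axis 0) are averaged out.\n",
"\n",
"```python\n",
"import numpy as onp  # plain NumPy; aliased to avoid shadowing mxnet's `np`\n",
"\n",
"def l2_loss_reference(pred, label):\n",
"    # Hypothetical helper: 0.5 * (pred - label)**2 per element,\n",
"    # then averaged over every axis except the batch axis (axis 0).\n",
"    pred = onp.asarray(pred, dtype=float)\n",
"    label = onp.asarray(label, dtype=float).reshape(pred.shape)\n",
"    squared = 0.5 * (pred - label) ** 2\n",
"    return squared.reshape(squared.shape[0], -1).mean(axis=1)\n",
"\n",
"# Same inputs as `x` and `y` above: ones versus twos.\n",
"print(l2_loss_reference([1.0, 1.0], [2.0, 2.0]))\n",
"```\n",
"\n",
"With one-dimensional inputs each element is its own sample, so the helper returns one loss value per element."
]
},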
{
"cell_type": "markdown",
"id": "0291add2",
"metadata": {},
"source": [
"Next we show how to use a loss function to compute gradients."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4cb56fea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2.0564232 2.0480375]\n"
]
}
],
"source": [
"X = np.random.uniform(size=(2, 4))\n",
"net = nn.Dense(1)\n",
"net.initialize()\n",
"with autograd.record():\n",
" l = loss(net(X), y)\n",
"print(l)"
]
},
{
"cell_type": "markdown",
"id": "83b65f53",
"metadata": {},
"source": [
"We can compute the gradients w.r.t. the loss function."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a4c9a26b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2.0926476 -1.6370986 -1.6798826 -2.2344642]]\n"
]
}
],
"source": [
"l.backward()\n",
"print(net.weight.grad())"
]
},
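{
"cell_type": "markdown",
"id": "sketch-grad-check",
"metadata": {},
"source": [
"To make the printed gradient less opaque, the sketch below works out the same kind of gradient by hand for a single linear layer. It is an illustration in plain NumPy under stated assumptions, not what MXNet runs internally: the helper names, the `*_demo` arrays and the zero bias are hypothetical, and it assumes that calling `backward()` on a vector of losses differentiates their sum.\n",
"\n",
"```python\n",
"import numpy as onp  # plain NumPy; aliased to avoid shadowing mxnet's `np`\n",
"\n",
"def dense_l2_weight_grad(X, w, b, y):\n",
"    # Gradient of sum_i 0.5 * (X_i . w + b - y_i)**2 with respect to w.\n",
"    residual = X @ w + b - y   # shape (n,)\n",
"    return residual @ X        # shape (d,)\n",
"\n",
"def numeric_grad(X, w, b, y, eps=1e-6):\n",
"    # Central finite differences, used only to check the formula above.\n",
"    def total_loss(w_):\n",
"        return 0.5 * onp.sum((X @ w_ + b - y) ** 2)\n",
"    grad = onp.zeros_like(w)\n",
"    for j in range(w.size):\n",
"        step = onp.zeros_like(w)\n",
"        step[j] = eps\n",
"        grad[j] = (total_loss(w + step) - total_loss(w - step)) / (2 * eps)\n",
"    return grad\n",
"\n",
"rng = onp.random.default_rng(0)\n",
"X_demo = rng.uniform(size=(2, 4))  # hypothetical inputs, same shape as X above\n",
"w_demo = rng.normal(size=4)        # hypothetical weights\n",
"y_demo = onp.array([2.0, 2.0])\n",
"analytic = dense_l2_weight_grad(X_demo, w_demo, 0.0, y_demo)\n",
"print(onp.allclose(analytic, numeric_grad(X_demo, w_demo, 0.0, y_demo), atol=1e-5))\n",
"```\n",
"\n",
"The analytic gradient is a sum of the input rows weighted by their residuals, so samples with larger errors pull the weights harder."
]
},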
{
"cell_type": "markdown",
"id": "33b90ca3",
"metadata": {},
"source": [
"## Loss functions\n",
"\n",
"Most commonly used loss functions can be divided into 2 categories: regression and classification.\n",
"\n",
"Let's first visualize several regression losses. We visualize the loss values versus the predicted values with label values fixed to be 0."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9fc23763",
"metadata": {},
"outputs": [],
"source": [
"def plot(x, y):\n",
" display.set_matplotlib_formats('svg')\n",
" plt.plot(x.asnumpy(), y.asnumpy())\n",
" plt.xlabel('x')\n",
" plt.ylabel('loss')\n",
" plt.show()\n",
"\n",
"def show_regression_loss(loss):\n",
" x = np.arange(-5, 5, .1)\n",
" y = loss(x, np.zeros_like(x))\n",
" plot(x, y)\n"
]
},
{
"cell_type": "markdown",
"id": "17aaef2b",
"metadata": {},
"source": [
"Then plot the classification losses with label values fixed to be 1."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1317c399",
"metadata": {},
"outputs": [],
"source": [
"def show_classification_loss(loss):\n",
" x = np.arange(-5, 5, .1)\n",
" y = loss(x, np.ones_like(x))\n",
" plot(x, y)"
]
},
{
"cell_type": "markdown",
"id": "6efdbc2d",
"metadata": {},
"source": [
"#### [L1 Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.L1Loss)\n",
"\n",
"L1 Loss, also called Mean Absolute Error, computes the sum of absolute distance between target values and the output of the neural network. It is defined as:\n",
"\n",
"$$ L = \\sum_i \\vert {label}_i - {pred}_i \\vert. $$\n",
"\n",
"It is a non-smooth function that can lead to non-convergence. It creates the same gradient for small and large loss values, which can be problematic for the learning process."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "862a2d15",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"398.560625pt\" height=\"310.86825pt\" viewBox=\"0 0 398.560625 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-01-05T04:48:36.327733</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.6.2, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 310.86825 \n",
"L 398.560625 310.86825 \n",
"L 398.560625 0 \n",
"L 0 0 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"L 34.240625 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"md875770d64\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#md875770d64\" x=\"83.266739\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- −4 -->\n",
" <g transform=\"translate(75.895645 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \n",
"L 4684 2272 \n",
"L 4684 1741 \n",
"L 678 1741 \n",
"L 678 2272 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#md875770d64\" x=\"148.853512\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- −2 -->\n",
" <g transform=\"translate(141.482418 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#md875770d64\" x=\"214.440285\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(211.259035 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#md875770d64\" x=\"280.027058\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(276.845808 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#md875770d64\" x=\"345.613831\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(342.432581 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- x -->\n",
" <g transform=\"translate(209.84125 301.588562) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
"L 2247 1797 \n",
"L 3578 0 \n",
"L 2900 0 \n",
"L 1881 1375 \n",
"L 863 0 \n",
"L 184 0 \n",
"L 1544 1831 \n",
"L 300 3500 \n",
"L 978 3500 \n",
"L 1906 2253 \n",
"L 2834 3500 \n",
"L 3513 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-78\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path id=\"m0c1bb61327\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m0c1bb61327\" x=\"34.240625\" y=\"261.216\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(20.878125 265.015219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#m0c1bb61327\" x=\"34.240625\" y=\"212.832\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 1 -->\n",
" <g transform=\"translate(20.878125 216.631219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#m0c1bb61327\" x=\"34.240625\" y=\"164.448\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(20.878125 168.247219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#m0c1bb61327\" x=\"34.240625\" y=\"116.064\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 3 -->\n",
" <g transform=\"translate(20.878125 119.863219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
"Q 3050 2419 3304 2112 \n",
"Q 3559 1806 3559 1356 \n",
"Q 3559 666 3084 287 \n",
"Q 2609 -91 1734 -91 \n",
"Q 1441 -91 1130 -33 \n",
"Q 819 25 488 141 \n",
"L 488 750 \n",
"Q 750 597 1062 519 \n",
"Q 1375 441 1716 441 \n",
"Q 2309 441 2620 675 \n",
"Q 2931 909 2931 1356 \n",
"Q 2931 1769 2642 2001 \n",
"Q 2353 2234 1838 2234 \n",
"L 1294 2234 \n",
"L 1294 2753 \n",
"L 1863 2753 \n",
"Q 2328 2753 2575 2939 \n",
"Q 2822 3125 2822 3475 \n",
"Q 2822 3834 2567 4026 \n",
"Q 2313 4219 1838 4219 \n",
"Q 1578 4219 1281 4162 \n",
"Q 984 4106 628 3988 \n",
"L 628 4550 \n",
"Q 988 4650 1302 4700 \n",
"Q 1616 4750 1894 4750 \n",
"Q 2613 4750 3031 4423 \n",
"Q 3450 4097 3450 3541 \n",
"Q 3450 3153 3228 2886 \n",
"Q 3006 2619 2597 2516 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#m0c1bb61327\" x=\"34.240625\" y=\"67.68\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(20.878125 71.479219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use xlink:href=\"#m0c1bb61327\" x=\"34.240625\" y=\"19.296\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 5 -->\n",
" <g transform=\"translate(20.878125 23.095219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
"L 3169 4666 \n",
"L 3169 4134 \n",
"L 1269 4134 \n",
"L 1269 2991 \n",
"Q 1406 3038 1543 3061 \n",
"Q 1681 3084 1819 3084 \n",
"Q 2600 3084 3056 2656 \n",
"Q 3513 2228 3513 1497 \n",
"Q 3513 744 3044 326 \n",
"Q 2575 -91 1722 -91 \n",
"Q 1428 -91 1123 -41 \n",
"Q 819 9 494 109 \n",
"L 494 744 \n",
"Q 775 591 1075 516 \n",
"Q 1375 441 1709 441 \n",
"Q 2250 441 2565 725 \n",
"Q 2881 1009 2881 1497 \n",
"Q 2881 1984 2565 2268 \n",
"Q 2250 2553 1709 2553 \n",
"Q 1456 2553 1204 2497 \n",
"Q 953 2441 691 2322 \n",
"L 691 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- loss -->\n",
" <g transform=\"translate(14.798438 149.913812) rotate(-90) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <path d=\"M 50.473352 19.296 \n",
"L 53.752688 24.134395 \n",
"L 57.032023 28.972791 \n",
"L 60.311374 33.811209 \n",
"L 63.59071 38.649605 \n",
"L 66.870046 43.488 \n",
"L 70.149381 48.326395 \n",
"L 73.428717 53.164791 \n",
"L 76.708068 58.003209 \n",
"L 79.987403 62.841605 \n",
"L 83.266739 67.68 \n",
"L 86.546074 72.518395 \n",
"L 89.825418 77.356802 \n",
"L 93.104761 82.195209 \n",
"L 96.384097 87.033605 \n",
"L 99.663432 91.872 \n",
"L 102.942768 96.710395 \n",
"L 106.222111 101.548802 \n",
"L 109.501454 106.387209 \n",
"L 112.78079 111.225605 \n",
"L 116.060125 116.064 \n",
"L 119.339469 120.902407 \n",
"L 122.618804 125.740802 \n",
"L 125.89814 130.579198 \n",
"L 129.177483 135.417605 \n",
"L 132.456819 140.256 \n",
"L 135.736162 145.094407 \n",
"L 139.015497 149.932802 \n",
"L 142.294833 154.771198 \n",
"L 145.574176 159.609605 \n",
"L 148.853512 164.448 \n",
"L 152.132855 169.286407 \n",
"L 155.412191 174.124802 \n",
"L 158.691526 178.963198 \n",
"L 161.97087 183.801605 \n",
"L 165.250205 188.64 \n",
"L 168.529548 193.478407 \n",
"L 171.808884 198.316802 \n",
"L 175.08822 203.155198 \n",
"L 178.367563 207.993605 \n",
"L 181.646898 212.832 \n",
"L 184.926234 217.670395 \n",
"L 188.205585 222.508814 \n",
"L 191.484921 227.347209 \n",
"L 194.764256 232.185605 \n",
"L 198.043592 237.024 \n",
"L 201.322927 241.862395 \n",
"L 204.602278 246.700814 \n",
"L 207.881614 251.539209 \n",
"L 211.160949 256.377605 \n",
"L 214.440285 261.216 \n",
"L 217.71962 256.377605 \n",
"L 220.998972 251.539186 \n",
"L 224.278307 246.700791 \n",
"L 227.557643 241.862395 \n",
"L 230.836978 237.024 \n",
"L 234.116314 232.185605 \n",
"L 237.395665 227.347186 \n",
"L 240.675 222.508791 \n",
"L 243.954336 217.670395 \n",
"L 247.233671 212.832 \n",
"L 250.513007 207.993605 \n",
"L 253.792358 203.155186 \n",
"L 257.071694 198.316791 \n",
"L 260.351029 193.478395 \n",
"L 263.630365 188.64 \n",
"L 266.9097 183.801605 \n",
"L 270.189051 178.963186 \n",
"L 273.468387 174.124791 \n",
"L 276.747722 169.286395 \n",
"L 280.027058 164.448 \n",
"L 283.306394 159.609605 \n",
"L 286.585745 154.771186 \n",
"L 289.86508 149.932791 \n",
"L 293.144416 145.094395 \n",
"L 296.423751 140.256 \n",
"L 299.703087 135.417605 \n",
"L 302.982438 130.579186 \n",
"L 306.261773 125.740791 \n",
"L 309.541109 120.902395 \n",
"L 312.820445 116.064 \n",
"L 316.099796 111.225582 \n",
"L 319.379116 106.387209 \n",
"L 322.658467 101.548791 \n",
"L 325.937818 96.710372 \n",
"L 329.217138 91.872 \n",
"L 332.496489 87.033582 \n",
"L 335.775809 82.195209 \n",
"L 339.05516 77.356791 \n",
"L 342.334511 72.518372 \n",
"L 345.613831 67.68 \n",
"L 348.893182 62.841582 \n",
"L 352.172502 58.003209 \n",
"L 355.451853 53.164791 \n",
"L 358.731204 48.326372 \n",
"L 362.010524 43.488 \n",
"L 365.289876 38.649582 \n",
"L 368.569195 33.811209 \n",
"L 371.848547 28.972791 \n",
"L 375.127898 24.134372 \n",
"\" clip-path=\"url(#pebd54ca7fc)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 34.240625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 34.240625 7.2 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"pebd54ca7fc\">\n",
" <rect x=\"34.240625\" y=\"7.2\" width=\"357.12\" height=\"266.112\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_regression_loss(gloss.L1Loss())"
]
},
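{
"cell_type": "markdown",
"id": "sketch-l1-grad",
"metadata": {},
"source": [
"The constant-gradient behaviour of the L1 loss can also be seen numerically. The following is a minimal sketch in plain NumPy rather than Gluon code: `l1_loss_and_grad` is a hypothetical helper, and `onp.sign` is used as the subgradient, which evaluates to 0 exactly at the kink.\n",
"\n",
"```python\n",
"import numpy as onp  # plain NumPy; aliased to avoid shadowing mxnet's `np`\n",
"\n",
"def l1_loss_and_grad(pred, label):\n",
"    # Elementwise |pred - label| and its (sub)gradient with respect to pred.\n",
"    diff = onp.asarray(pred, dtype=float) - onp.asarray(label, dtype=float)\n",
"    return onp.abs(diff), onp.sign(diff)\n",
"\n",
"# Two predictions far from the label 0 and two close to it.\n",
"losses, grads = l1_loss_and_grad([-4.0, -0.1, 0.1, 4.0], 0.0)\n",
"print(losses)\n",
"print(grads)\n",
"```\n",
"\n",
"The loss values differ by a factor of 40, yet every gradient has magnitude 1, so an optimizer with a fixed learning rate takes equally large steps near the target and far from it."
]
},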
{
"cell_type": "markdown",
"id": "71d29493",
"metadata": {},
"source": [
"#### [L2 Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.L2Loss)\n",
"\n",
"L2Loss, also called Mean Squared Error, is a regression loss function that computes the squared distances between the target values and the output of the neural network. It is defined as:\n",
"\n",
"$$ L = \\frac{1}{2} \\sum_i \\vert {label}_i - {pred}_i \\vert^2. $$\n",
"\n",
"Compared to L1, L2 loss it is a smooth function and it creates larger gradients for large loss values. However due to the squaring it puts high weight on outliers."
]
},
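{
"cell_type": "markdown",
"id": "sketch-outlier-share",
"metadata": {},
"source": [
"A small numeric comparison shows how much more weight the squared loss gives to an outlier. This is a sketch in plain NumPy with made-up numbers: the `residuals` array is hypothetical, and its last entry plays the role of the outlier.\n",
"\n",
"```python\n",
"import numpy as onp  # plain NumPy; aliased to avoid shadowing mxnet's `np`\n",
"\n",
"residuals = onp.array([0.1, -0.2, 0.1, 5.0])  # hypothetical errors; the last is an outlier\n",
"\n",
"l1_terms = onp.abs(residuals)\n",
"l2_terms = 0.5 * residuals ** 2\n",
"\n",
"# Fraction of the total loss contributed by the outlier under each loss.\n",
"l1_share = l1_terms[-1] / l1_terms.sum()\n",
"l2_share = l2_terms[-1] / l2_terms.sum()\n",
"print(f'L1 share: {l1_share:.3f}, L2 share: {l2_share:.3f}')\n",
"```\n",
"\n",
"Under the squared loss the single outlier accounts for a larger share of the total than under the absolute loss, so it dominates the gradient more strongly."
]
},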
{
"cell_type": "code",
"execution_count": 10,
"id": "86ce4190",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"404.923125pt\" height=\"310.86825pt\" viewBox=\"0 0 404.923125 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-01-05T04:48:36.427316</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.6.2, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 310.86825 \n",
"L 404.923125 310.86825 \n",
"L 404.923125 0 \n",
"L 0 0 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 40.603125 273.312 \n",
"L 397.723125 273.312 \n",
"L 397.723125 7.2 \n",
"L 40.603125 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"m32f0481a9e\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m32f0481a9e\" x=\"89.629239\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- −4 -->\n",
" <g transform=\"translate(82.258145 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \n",
"L 4684 2272 \n",
"L 4684 1741 \n",
"L 678 1741 \n",
"L 678 2272 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#m32f0481a9e\" x=\"155.216012\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- −2 -->\n",
" <g transform=\"translate(147.844918 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#m32f0481a9e\" x=\"220.802785\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(217.621535 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m32f0481a9e\" x=\"286.389558\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(283.208308 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#m32f0481a9e\" x=\"351.976331\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(348.795081 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- x -->\n",
" <g transform=\"translate(216.20375 301.588562) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
"L 2247 1797 \n",
"L 3578 0 \n",
"L 2900 0 \n",
"L 1881 1375 \n",
"L 863 0 \n",
"L 184 0 \n",
"L 1544 1831 \n",
"L 300 3500 \n",
"L 978 3500 \n",
"L 1906 2253 \n",
"L 2834 3500 \n",
"L 3513 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-78\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path id=\"m2c2ec671b1\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"261.216\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(27.240625 265.015219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"222.5088\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(27.240625 226.308019) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"183.8016\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(27.240625 187.600819) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"145.0944\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 6 -->\n",
" <g transform=\"translate(27.240625 148.893619) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
"Q 1688 2584 1439 2293 \n",
"Q 1191 2003 1191 1497 \n",
"Q 1191 994 1439 701 \n",
"Q 1688 409 2113 409 \n",
"Q 2538 409 2786 701 \n",
"Q 3034 994 3034 1497 \n",
"Q 3034 2003 2786 2293 \n",
"Q 2538 2584 2113 2584 \n",
"z\n",
"M 3366 4563 \n",
"L 3366 3988 \n",
"Q 3128 4100 2886 4159 \n",
"Q 2644 4219 2406 4219 \n",
"Q 1781 4219 1451 3797 \n",
"Q 1122 3375 1075 2522 \n",
"Q 1259 2794 1537 2939 \n",
"Q 1816 3084 2150 3084 \n",
"Q 2853 3084 3261 2657 \n",
"Q 3669 2231 3669 1497 \n",
"Q 3669 778 3244 343 \n",
"Q 2819 -91 2113 -91 \n",
"Q 1303 -91 875 529 \n",
"Q 447 1150 447 2328 \n",
"Q 447 3434 972 4092 \n",
"Q 1497 4750 2381 4750 \n",
"Q 2619 4750 2861 4703 \n",
"Q 3103 4656 3366 4563 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"106.3872\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 8 -->\n",
" <g transform=\"translate(27.240625 110.186419) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-38\" d=\"M 2034 2216 \n",
"Q 1584 2216 1326 1975 \n",
"Q 1069 1734 1069 1313 \n",
"Q 1069 891 1326 650 \n",
"Q 1584 409 2034 409 \n",
"Q 2484 409 2743 651 \n",
"Q 3003 894 3003 1313 \n",
"Q 3003 1734 2745 1975 \n",
"Q 2488 2216 2034 2216 \n",
"z\n",
"M 1403 2484 \n",
"Q 997 2584 770 2862 \n",
"Q 544 3141 544 3541 \n",
"Q 544 4100 942 4425 \n",
"Q 1341 4750 2034 4750 \n",
"Q 2731 4750 3128 4425 \n",
"Q 3525 4100 3525 3541 \n",
"Q 3525 3141 3298 2862 \n",
"Q 3072 2584 2669 2484 \n",
"Q 3125 2378 3379 2068 \n",
"Q 3634 1759 3634 1313 \n",
"Q 3634 634 3220 271 \n",
"Q 2806 -91 2034 -91 \n",
"Q 1263 -91 848 271 \n",
"Q 434 634 434 1313 \n",
"Q 434 1759 690 2068 \n",
"Q 947 2378 1403 2484 \n",
"z\n",
"M 1172 3481 \n",
"Q 1172 3119 1398 2916 \n",
"Q 1625 2713 2034 2713 \n",
"Q 2441 2713 2670 2916 \n",
"Q 2900 3119 2900 3481 \n",
"Q 2900 3844 2670 4047 \n",
"Q 2441 4250 2034 4250 \n",
"Q 1625 4250 1398 4047 \n",
"Q 1172 3844 1172 3481 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-38\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"67.68\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 10 -->\n",
" <g transform=\"translate(20.878125 71.479219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_7\">\n",
" <g id=\"line2d_12\">\n",
" <g>\n",
" <use xlink:href=\"#m2c2ec671b1\" x=\"40.603125\" y=\"28.9728\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- 12 -->\n",
" <g transform=\"translate(20.878125 32.772019) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"63.623047\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_14\">\n",
" <!-- loss -->\n",
" <g transform=\"translate(14.798438 149.913812) rotate(-90) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_13\">\n",
" <path d=\"M 56.835852 19.296 \n",
"L 60.115188 28.87603 \n",
"L 63.394523 38.262519 \n",
"L 66.673874 47.455505 \n",
"L 69.95321 56.454913 \n",
"L 73.232546 65.2608 \n",
"L 76.511881 73.873146 \n",
"L 79.791217 82.291952 \n",
"L 83.070568 90.517272 \n",
"L 86.349903 98.548997 \n",
"L 89.629239 106.3872 \n",
"L 92.908574 114.031862 \n",
"L 96.187918 121.483012 \n",
"L 99.467261 128.740621 \n",
"L 102.746597 135.804681 \n",
"L 106.025932 142.6752 \n",
"L 109.305268 149.352188 \n",
"L 112.584611 155.835654 \n",
"L 115.863954 162.125579 \n",
"L 119.14329 168.221955 \n",
"L 122.422625 174.1248 \n",
"L 125.701969 179.834123 \n",
"L 128.981304 185.349891 \n",
"L 132.26064 190.672124 \n",
"L 135.539983 195.800839 \n",
"L 138.819319 200.736 \n",
"L 142.098662 205.477639 \n",
"L 145.377997 210.025728 \n",
"L 148.657333 214.380287 \n",
"L 151.936676 218.541318 \n",
"L 155.216012 222.5088 \n",
"L 158.495355 226.282758 \n",
"L 161.774691 229.86317 \n",
"L 165.054026 233.250047 \n",
"L 168.33337 236.443395 \n",
"L 171.612705 239.4432 \n",
"L 174.892048 242.249476 \n",
"L 178.171384 244.86221 \n",
"L 181.45072 247.281407 \n",
"L 184.730063 249.507074 \n",
"L 188.009398 251.5392 \n",
"L 191.288734 253.37779 \n",
"L 194.568085 255.022852 \n",
"L 197.847421 256.474371 \n",
"L 201.126756 257.732353 \n",
"L 204.406092 258.7968 \n",
"L 207.685427 259.667711 \n",
"L 210.964778 260.34509 \n",
"L 214.244114 260.828929 \n",
"L 217.523449 261.119232 \n",
"L 220.802785 261.216 \n",
"L 224.08212 261.119232 \n",
"L 227.361472 260.828927 \n",
"L 230.640807 260.345087 \n",
"L 233.920143 259.667711 \n",
"L 237.199478 258.7968 \n",
"L 240.478814 257.732353 \n",
"L 243.758165 256.474364 \n",
"L 247.0375 255.022845 \n",
"L 250.316836 253.37779 \n",
"L 253.596171 251.5392 \n",
"L 256.875507 249.507074 \n",
"L 260.154858 247.281402 \n",
"L 263.434194 244.862203 \n",
"L 266.713529 242.249469 \n",
"L 269.992865 239.4432 \n",
"L 273.2722 236.443395 \n",
"L 276.551551 233.250038 \n",
"L 279.830887 229.863161 \n",
"L 283.110222 226.282748 \n",
"L 286.389558 222.5088 \n",
"L 289.668894 218.541318 \n",
"L 292.948245 214.380277 \n",
"L 296.22758 210.025719 \n",
"L 299.506916 205.47763 \n",
"L 302.786251 200.736 \n",
"L 306.065587 195.800839 \n",
"L 309.344938 190.672115 \n",
"L 312.624273 185.349877 \n",
"L 315.903609 179.834104 \n",
"L 319.182945 174.1248 \n",
"L 322.462296 168.221928 \n",
"L 325.741616 162.125579 \n",
"L 329.020967 155.835635 \n",
"L 332.300318 149.352151 \n",
"L 335.579638 142.6752 \n",
"L 338.858989 135.804644 \n",
"L 342.138309 128.740621 \n",
"L 345.41766 121.482994 \n",
"L 348.697011 114.031825 \n",
"L 351.976331 106.3872 \n",
"L 355.255682 98.54896 \n",
"L 358.535002 90.517272 \n",
"L 361.814353 82.291952 \n",
"L 365.093704 73.873109 \n",
"L 368.373024 65.2608 \n",
"L 371.652376 56.454877 \n",
"L 374.931695 47.455505 \n",
"L 378.211047 38.262519 \n",
"L 381.490398 28.875974 \n",
"\" clip-path=\"url(#pacac1234b9)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 40.603125 273.312 \n",
"L 40.603125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 397.723125 273.312 \n",
"L 397.723125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 40.603125 273.312 \n",
"L 397.723125 273.312 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 40.603125 7.2 \n",
"L 397.723125 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"pacac1234b9\">\n",
" <rect x=\"40.603125\" y=\"7.2\" width=\"357.12\" height=\"266.112\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_regression_loss(gloss.L2Loss())"
]
},
  {
   "cell_type": "markdown",
   "id": "c06d6cf9",
   "metadata": {},
   "source": [
    "#### [Huber Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.HuberLoss)\n",
    "\n",
    "HuberLoss combines the advantages of L1 and L2 loss. It computes a smoothed L1 loss: where the absolute error is at least a threshold $$\\rho$$, the loss grows linearly like L1 (shifted down by $$\\rho / 2$$); below the threshold it is quadratic like L2 (scaled by $$1 / \\rho$$). It is defined as:\n",
    "$$\n",
    "\\begin{split}L = \\sum_i \\begin{cases} \\frac{1}{2 \\rho} ({label}_i - {pred}_i)^2 &\n",
    "                   \\text{ if } |{label}_i - {pred}_i| < \\rho \\\\\n",
    "                   |{label}_i - {pred}_i| - \\frac{\\rho}{2} &\n",
    "                   \\text{ otherwise }\n",
    "    \\end{cases}\\end{split}\n",
    "$$"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ab9fbf2d",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"398.560625pt\" height=\"310.86825pt\" viewBox=\"0 0 398.560625 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-01-05T04:48:36.521757</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.6.2, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 310.86825 \n",
"L 398.560625 310.86825 \n",
"L 398.560625 0 \n",
"L 0 0 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"L 34.240625 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"m547a1ccc8d\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m547a1ccc8d\" x=\"83.266739\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- −4 -->\n",
" <g transform=\"translate(75.895645 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \n",
"L 4684 2272 \n",
"L 4684 1741 \n",
"L 678 1741 \n",
"L 678 2272 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#m547a1ccc8d\" x=\"148.853512\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- −2 -->\n",
" <g transform=\"translate(141.482418 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#m547a1ccc8d\" x=\"214.440285\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(211.259035 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m547a1ccc8d\" x=\"280.027058\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(276.845808 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#m547a1ccc8d\" x=\"345.613831\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(342.432581 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- x -->\n",
" <g transform=\"translate(209.84125 301.588562) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
"L 2247 1797 \n",
"L 3578 0 \n",
"L 2900 0 \n",
"L 1881 1375 \n",
"L 863 0 \n",
"L 184 0 \n",
"L 1544 1831 \n",
"L 300 3500 \n",
"L 978 3500 \n",
"L 1906 2253 \n",
"L 2834 3500 \n",
"L 3513 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-78\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path id=\"m7170cebb39\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m7170cebb39\" x=\"34.240625\" y=\"261.216\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(20.878125 265.015219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#m7170cebb39\" x=\"34.240625\" y=\"207.456\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 1 -->\n",
" <g transform=\"translate(20.878125 211.255219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#m7170cebb39\" x=\"34.240625\" y=\"153.696\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(20.878125 157.495219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#m7170cebb39\" x=\"34.240625\" y=\"99.936\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 3 -->\n",
" <g transform=\"translate(20.878125 103.735219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
"Q 3050 2419 3304 2112 \n",
"Q 3559 1806 3559 1356 \n",
"Q 3559 666 3084 287 \n",
"Q 2609 -91 1734 -91 \n",
"Q 1441 -91 1130 -33 \n",
"Q 819 25 488 141 \n",
"L 488 750 \n",
"Q 750 597 1062 519 \n",
"Q 1375 441 1716 441 \n",
"Q 2309 441 2620 675 \n",
"Q 2931 909 2931 1356 \n",
"Q 2931 1769 2642 2001 \n",
"Q 2353 2234 1838 2234 \n",
"L 1294 2234 \n",
"L 1294 2753 \n",
"L 1863 2753 \n",
"Q 2328 2753 2575 2939 \n",
"Q 2822 3125 2822 3475 \n",
"Q 2822 3834 2567 4026 \n",
"Q 2313 4219 1838 4219 \n",
"Q 1578 4219 1281 4162 \n",
"Q 984 4106 628 3988 \n",
"L 628 4550 \n",
"Q 988 4650 1302 4700 \n",
"Q 1616 4750 1894 4750 \n",
"Q 2613 4750 3031 4423 \n",
"Q 3450 4097 3450 3541 \n",
"Q 3450 3153 3228 2886 \n",
"Q 3006 2619 2597 2516 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#m7170cebb39\" x=\"34.240625\" y=\"46.176\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(20.878125 49.975219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- loss -->\n",
" <g transform=\"translate(14.798438 149.913812) rotate(-90) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_11\">\n",
" <path d=\"M 50.473352 19.296 \n",
"L 53.752688 24.671995 \n",
"L 57.032023 30.04799 \n",
"L 60.311374 35.42401 \n",
"L 63.59071 40.800005 \n",
"L 66.870046 46.176 \n",
"L 70.149381 51.551995 \n",
"L 73.428717 56.92799 \n",
"L 76.708068 62.30401 \n",
"L 79.987403 67.680005 \n",
"L 83.266739 73.056 \n",
"L 86.546074 78.431995 \n",
"L 89.825418 83.808003 \n",
"L 93.104761 89.18401 \n",
"L 96.384097 94.560005 \n",
"L 99.663432 99.936 \n",
"L 102.942768 105.311995 \n",
"L 106.222111 110.688003 \n",
"L 109.501454 116.06401 \n",
"L 112.78079 121.440005 \n",
"L 116.060125 126.816 \n",
"L 119.339469 132.192008 \n",
"L 122.618804 137.568003 \n",
"L 125.89814 142.943997 \n",
"L 129.177483 148.320005 \n",
"L 132.456819 153.696 \n",
"L 135.736162 159.072008 \n",
"L 139.015497 164.448003 \n",
"L 142.294833 169.823997 \n",
"L 145.574176 175.200005 \n",
"L 148.853512 180.576 \n",
"L 152.132855 185.952008 \n",
"L 155.412191 191.328003 \n",
"L 158.691526 196.703997 \n",
"L 161.97087 202.080005 \n",
"L 165.250205 207.456 \n",
"L 168.529548 212.832008 \n",
"L 171.808884 218.208003 \n",
"L 175.08822 223.583997 \n",
"L 178.367563 228.960005 \n",
"L 181.646898 234.336 \n",
"L 184.926234 239.443195 \n",
"L 188.205585 244.012812 \n",
"L 191.484921 248.044807 \n",
"L 194.764256 251.539203 \n",
"L 198.043592 254.496 \n",
"L 201.322927 256.915198 \n",
"L 204.602278 258.796805 \n",
"L 207.881614 260.140802 \n",
"L 211.160949 260.947201 \n",
"L 214.440285 261.216 \n",
"L 217.71962 260.947201 \n",
"L 220.998972 260.140797 \n",
"L 224.278307 258.796797 \n",
"L 227.557643 256.915198 \n",
"L 230.836978 254.496 \n",
"L 234.116314 251.539203 \n",
"L 237.395665 248.044789 \n",
"L 240.675 244.012792 \n",
"L 243.954336 239.443195 \n",
"L 247.233671 234.336 \n",
"L 250.513007 228.960005 \n",
"L 253.792358 223.583985 \n",
"L 257.071694 218.20799 \n",
"L 260.351029 212.831995 \n",
"L 263.630365 207.456 \n",
"L 266.9097 202.080005 \n",
"L 270.189051 196.703985 \n",
"L 273.468387 191.32799 \n",
"L 276.747722 185.951995 \n",
"L 280.027058 180.576 \n",
"L 283.306394 175.200005 \n",
"L 286.585745 169.823985 \n",
"L 289.86508 164.44799 \n",
"L 293.144416 159.071995 \n",
"L 296.423751 153.696 \n",
"L 299.703087 148.320005 \n",
"L 302.982438 142.943985 \n",
"L 306.261773 137.56799 \n",
"L 309.541109 132.191995 \n",
"L 312.820445 126.816 \n",
"L 316.099796 121.439979 \n",
"L 319.379116 116.06401 \n",
"L 322.658467 110.68799 \n",
"L 325.937818 105.311969 \n",
"L 329.217138 99.936 \n",
"L 332.496489 94.559979 \n",
"L 335.775809 89.18401 \n",
"L 339.05516 83.80799 \n",
"L 342.334511 78.431969 \n",
"L 345.613831 73.056 \n",
"L 348.893182 67.679979 \n",
"L 352.172502 62.30401 \n",
"L 355.451853 56.92799 \n",
"L 358.731204 51.551969 \n",
"L 362.010524 46.176 \n",
"L 365.289876 40.799979 \n",
"L 368.569195 35.42401 \n",
"L 371.848547 30.04799 \n",
"L 375.127898 24.671969 \n",
"\" clip-path=\"url(#p91e2447995)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 34.240625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 34.240625 7.2 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p91e2447995\">\n",
" <rect x=\"34.240625\" y=\"7.2\" width=\"357.12\" height=\"266.112\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_regression_loss(gloss.HuberLoss(rho=1))"
]
},
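  {
   "cell_type": "markdown",
   "id": "e7a1c4b2",
   "metadata": {},
   "source": [
    "As a rough worked check of the Huber definition, take $$\\rho = 1$$ (the value passed to `HuberLoss` in the plot above) and two absolute errors chosen here purely for illustration, 0.5 and 3:\n",
    "$$\n",
    "\\begin{split}\n",
    "|{label}_i - {pred}_i| = 0.5 < \\rho &: \\quad \\frac{1}{2 \\cdot 1} (0.5)^2 = 0.125 \\\\\n",
    "|{label}_i - {pred}_i| = 3 \\geq \\rho &: \\quad 3 - \\frac{1}{2} = 2.5\n",
    "\\end{split}\n",
    "$$\n",
    "At an error of exactly $$\\rho$$ both branches evaluate to $$\\rho / 2$$, so the quadratic and linear pieces join without a jump."
   ]
  },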
  {
   "cell_type": "markdown",
   "id": "d5bdaf1a",
   "metadata": {},
   "source": [
    "An example of where Huber Loss is used can be found in [Deep Q Network](https://openai.com/blog/openai-baselines-dqn/).\n",
    "\n",
    "#### [Cross Entropy Loss with Sigmoid](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.SigmoidBinaryCrossEntropyLoss)\n",
    "\n",
    "Binary Cross Entropy is a loss function used for binary classification problems, e.g. classifying images into two classes. Cross entropy measures the difference between two probability distributions. For a ground-truth label $$y_i$$ (0 or 1) and a predicted probability $$p_i$$ of the positive class, it is defined as:\n",
    "$$L = -\\sum_i \\left( y_i \\log(p_i) + (1 - y_i) \\log(1 - p_i) \\right)$$\n",
    "By default, a sigmoid activation is applied to the network output before the loss is computed. If your network already has a `sigmoid` activation as its last layer, set `from_sigmoid` to `True` to avoid applying the sigmoid function twice."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1d76abeb",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"398.560625pt\" height=\"310.86825pt\" viewBox=\"0 0 398.560625 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-01-05T04:48:36.616969</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.6.2, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 310.86825 \n",
"L 398.560625 310.86825 \n",
"L 398.560625 0 \n",
"L 0 0 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"L 34.240625 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"m08896495b1\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m08896495b1\" x=\"83.266739\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- −4 -->\n",
" <g transform=\"translate(75.895645 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \n",
"L 4684 2272 \n",
"L 4684 1741 \n",
"L 678 1741 \n",
"L 678 2272 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#m08896495b1\" x=\"148.853512\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- −2 -->\n",
" <g transform=\"translate(141.482418 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#m08896495b1\" x=\"214.440285\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(211.259035 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m08896495b1\" x=\"280.027058\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(276.845808 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#m08896495b1\" x=\"345.613831\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(342.432581 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- x -->\n",
" <g transform=\"translate(209.84125 301.588562) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
"L 2247 1797 \n",
"L 3578 0 \n",
"L 2900 0 \n",
"L 1881 1375 \n",
"L 863 0 \n",
"L 184 0 \n",
"L 1544 1831 \n",
"L 300 3500 \n",
"L 978 3500 \n",
"L 1906 2253 \n",
"L 2834 3500 \n",
"L 3513 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-78\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path id=\"me40cd8de18\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#me40cd8de18\" x=\"34.240625\" y=\"261.575011\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(20.878125 265.37423) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#me40cd8de18\" x=\"34.240625\" y=\"213.1842\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 1 -->\n",
" <g transform=\"translate(20.878125 216.983419) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#me40cd8de18\" x=\"34.240625\" y=\"164.79339\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(20.878125 168.592609) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#me40cd8de18\" x=\"34.240625\" y=\"116.40258\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 3 -->\n",
" <g transform=\"translate(20.878125 120.201798) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
"Q 3050 2419 3304 2112 \n",
"Q 3559 1806 3559 1356 \n",
"Q 3559 666 3084 287 \n",
"Q 2609 -91 1734 -91 \n",
"Q 1441 -91 1130 -33 \n",
"Q 819 25 488 141 \n",
"L 488 750 \n",
"Q 750 597 1062 519 \n",
"Q 1375 441 1716 441 \n",
"Q 2309 441 2620 675 \n",
"Q 2931 909 2931 1356 \n",
"Q 2931 1769 2642 2001 \n",
"Q 2353 2234 1838 2234 \n",
"L 1294 2234 \n",
"L 1294 2753 \n",
"L 1863 2753 \n",
"Q 2328 2753 2575 2939 \n",
"Q 2822 3125 2822 3475 \n",
"Q 2822 3834 2567 4026 \n",
"Q 2313 4219 1838 4219 \n",
"Q 1578 4219 1281 4162 \n",
"Q 984 4106 628 3988 \n",
"L 628 4550 \n",
"Q 988 4650 1302 4700 \n",
"Q 1616 4750 1894 4750 \n",
"Q 2613 4750 3031 4423 \n",
"Q 3450 4097 3450 3541 \n",
"Q 3450 3153 3228 2886 \n",
"Q 3006 2619 2597 2516 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#me40cd8de18\" x=\"34.240625\" y=\"68.011769\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(20.878125 71.810988) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use xlink:href=\"#me40cd8de18\" x=\"34.240625\" y=\"19.620959\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 5 -->\n",
" <g transform=\"translate(20.878125 23.420177) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
"L 3169 4666 \n",
"L 3169 4134 \n",
"L 1269 4134 \n",
"L 1269 2991 \n",
"Q 1406 3038 1543 3061 \n",
"Q 1681 3084 1819 3084 \n",
"Q 2600 3084 3056 2656 \n",
"Q 3513 2228 3513 1497 \n",
"Q 3513 744 3044 326 \n",
"Q 2575 -91 1722 -91 \n",
"Q 1428 -91 1123 -41 \n",
"Q 819 9 494 109 \n",
"L 494 744 \n",
"Q 775 591 1075 516 \n",
"Q 1375 441 1709 441 \n",
"Q 2250 441 2565 725 \n",
"Q 2881 1009 2881 1497 \n",
"Q 2881 1984 2565 2268 \n",
"Q 2250 2553 1709 2553 \n",
"Q 1456 2553 1204 2497 \n",
"Q 953 2441 691 2322 \n",
"L 691 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- loss -->\n",
" <g transform=\"translate(14.798438 149.913812) rotate(-90) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <path d=\"M 50.473352 19.296 \n",
"L 53.752688 24.101018 \n",
"L 57.032023 28.902506 \n",
"L 60.311374 33.700072 \n",
"L 63.59071 38.493299 \n",
"L 66.870046 43.28175 \n",
"L 70.149381 48.06494 \n",
"L 73.428717 52.842338 \n",
"L 76.708068 57.613345 \n",
"L 79.987403 62.37729 \n",
"L 83.266739 67.133483 \n",
"L 86.546074 71.881104 \n",
"L 89.825418 76.619321 \n",
"L 93.104761 81.347178 \n",
"L 96.384097 86.063625 \n",
"L 99.663432 90.767531 \n",
"L 102.942768 95.45765 \n",
"L 106.222111 100.132655 \n",
"L 109.501454 104.791057 \n",
"L 112.78079 109.431208 \n",
"L 116.060125 114.0514 \n",
"L 119.339469 118.649716 \n",
"L 122.618804 123.224093 \n",
"L 125.89814 127.772315 \n",
"L 129.177483 132.291959 \n",
"L 132.456819 136.780441 \n",
"L 135.736162 141.235004 \n",
"L 139.015497 145.652624 \n",
"L 142.294833 150.030163 \n",
"L 145.574176 154.36423 \n",
"L 148.853512 158.651237 \n",
"L 152.132855 162.887433 \n",
"L 155.412191 167.068845 \n",
"L 158.691526 171.19133 \n",
"L 161.97087 175.250614 \n",
"L 165.250205 179.242244 \n",
"L 168.529548 183.161702 \n",
"L 171.808884 187.004362 \n",
"L 175.08822 190.765586 \n",
"L 178.367563 194.440731 \n",
"L 181.646898 198.025211 \n",
"L 184.926234 201.514565 \n",
"L 188.205585 204.904509 \n",
"L 191.484921 208.190952 \n",
"L 194.764256 211.370132 \n",
"L 198.043592 214.438636 \n",
"L 201.322927 217.39346 \n",
"L 204.602278 220.232075 \n",
"L 207.881614 222.952431 \n",
"L 211.160949 225.553055 \n",
"L 214.440285 228.033057 \n",
"L 217.71962 230.392131 \n",
"L 220.998972 232.630592 \n",
"L 224.278307 234.749316 \n",
"L 227.557643 236.749789 \n",
"L 230.836978 238.634041 \n",
"L 234.116314 240.404614 \n",
"L 237.395665 242.064515 \n",
"L 240.675 243.617152 \n",
"L 243.954336 245.066299 \n",
"L 247.233671 246.416024 \n",
"L 250.513007 247.670621 \n",
"L 253.792358 248.834561 \n",
"L 257.071694 249.912418 \n",
"L 260.351029 250.908835 \n",
"L 263.630365 251.828459 \n",
"L 266.9097 252.675904 \n",
"L 270.189051 253.455711 \n",
"L 273.468387 254.172302 \n",
"L 276.747722 254.829973 \n",
"L 280.027058 255.432862 \n",
"L 283.306394 255.984927 \n",
"L 286.585745 256.489945 \n",
"L 289.86508 256.951489 \n",
"L 293.144416 257.37294 \n",
"L 296.423751 257.757473 \n",
"L 299.703087 258.108066 \n",
"L 302.982438 258.427501 \n",
"L 306.261773 258.718365 \n",
"L 309.541109 258.983065 \n",
"L 312.820445 259.22383 \n",
"L 316.099796 259.442721 \n",
"L 319.379116 259.641636 \n",
"L 322.658467 259.822332 \n",
"L 325.937818 259.986415 \n",
"L 329.217138 260.135364 \n",
"L 332.496489 260.270536 \n",
"L 335.775809 260.393169 \n",
"L 339.05516 260.504402 \n",
"L 342.334511 260.605271 \n",
"L 345.613831 260.696721 \n",
"L 348.893182 260.779619 \n",
"L 352.172502 260.85475 \n",
"L 355.451853 260.922832 \n",
"L 358.731204 260.984518 \n",
"L 362.010524 261.040402 \n",
"L 365.289876 261.091023 \n",
"L 368.569195 261.136873 \n",
"L 371.848547 261.178397 \n",
"L 375.127898 261.216 \n",
"\" clip-path=\"url(#p7936d89d3a)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 34.240625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 34.240625 7.2 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p7936d89d3a\">\n",
" <rect x=\"34.240625\" y=\"7.2\" width=\"357.12\" height=\"266.112\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_classification_loss(gloss.SigmoidBinaryCrossEntropyLoss())"
]
},
{
"cell_type": "markdown",
"id": "9a4e320c",
"metadata": {},
"source": [
"#### [Cross Entropy Loss with Softmax](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.SoftmaxCrossEntropyLoss)\n",
"\n",
"In classification, we often apply the\n",
"softmax operator to the predicted outputs to obtain prediction probabilities,\n",
"and then apply the cross entropy loss against the true labels:\n",
"\n",
"$$ \\begin{align}\\begin{aligned}p = \\text{softmax}({pred})\\\\L = -\\sum_i \\sum_j {label}_j \\log p_{ij}\\end{aligned}\\end{align}\n",
"$$\n",
"\n",
"Running these two steps one-by-one, however, may lead to numerical instabilities. The `loss` module provides a single operators with softmax and cross entropy fused to avoid such problem."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6bb6f7ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9.000123 , 6.0024757])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss = gloss.SoftmaxCrossEntropyLoss()\n",
"x = np.array([[1, 10], [8, 2]])\n",
"y = np.array([0, 1])\n",
"loss(x, y)"
]
},
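 {
 "cell_type": "markdown",
 "id": "c3e81a7f",
 "metadata": {},
 "source": [
 "As a rough check on the values above, the loss for one sample with integer class label $y$ can be written in log-sum-exp form. This is a worked sketch rather than a description of how the operator is implemented internally; $m = \\max_j {pred}_j$ is a helper symbol introduced here for illustration, and $L_1$, $L_2$ denote the losses of the two samples in the example:\n",
 "\n",
 "$$\n",
 "\\begin{aligned}\n",
 "L &= -\\log \\frac{\\exp({pred}_y)}{\\sum_j \\exp({pred}_j)} = -{pred}_y + m + \\log \\sum_j \\exp({pred}_j - m) \\\\\n",
 "L_1 &= -1 + 10 + \\log(e^{-9} + 1) \\approx 9.000123 \\\\\n",
 "L_2 &= -2 + 8 + \\log(1 + e^{-6}) \\approx 6.002476\n",
 "\\end{aligned}\n",
 "$$\n",
 "\n",
 "Subtracting $m$ keeps every exponent at or below zero, so no term can overflow. This is one common way to see why a fused formulation tends to be better behaved than computing the softmax first and taking its logarithm afterwards."
 ]
 },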
{
"cell_type": "markdown",
"id": "4f7db011",
"metadata": {},
"source": [
"#### [Hinge Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.HingeLoss)\n",
"\n",
"Commonly used in Support Vector Machines (SVMs), Hinge Loss is used to additionally penalize predictions that are correct but fall within a margin between classes (the region around a decision boundary). Unlike `SoftmaxCrossEntropyLoss`, it's rarely used for neural network training. It is defined as:\n",
"\n",
"$$\n",
"L = \\sum_i max(0, {margin} - {pred}_i \\cdot {label}_i)\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "d1330884",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"398.560625pt\" height=\"310.86825pt\" viewBox=\"0 0 398.560625 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-01-05T04:48:36.725056</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.6.2, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 310.86825 \n",
"L 398.560625 310.86825 \n",
"L 398.560625 0 \n",
"L 0 0 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"L 34.240625 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"mcdb36fe582\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mcdb36fe582\" x=\"83.266739\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- −4 -->\n",
" <g transform=\"translate(75.895645 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \n",
"L 4684 2272 \n",
"L 4684 1741 \n",
"L 678 1741 \n",
"L 678 2272 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#mcdb36fe582\" x=\"148.853512\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- −2 -->\n",
" <g transform=\"translate(141.482418 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#mcdb36fe582\" x=\"214.440285\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(211.259035 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#mcdb36fe582\" x=\"280.027058\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(276.845808 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#mcdb36fe582\" x=\"345.613831\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(342.432581 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- x -->\n",
" <g transform=\"translate(209.84125 301.588562) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
"L 2247 1797 \n",
"L 3578 0 \n",
"L 2900 0 \n",
"L 1881 1375 \n",
"L 863 0 \n",
"L 184 0 \n",
"L 1544 1831 \n",
"L 300 3500 \n",
"L 978 3500 \n",
"L 1906 2253 \n",
"L 2834 3500 \n",
"L 3513 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-78\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path id=\"mb64e12fe4a\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"261.216\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(20.878125 265.015219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"220.896\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 1 -->\n",
" <g transform=\"translate(20.878125 224.695219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"180.576\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(20.878125 184.375219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"140.256\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 3 -->\n",
" <g transform=\"translate(20.878125 144.055219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
"Q 3050 2419 3304 2112 \n",
"Q 3559 1806 3559 1356 \n",
"Q 3559 666 3084 287 \n",
"Q 2609 -91 1734 -91 \n",
"Q 1441 -91 1130 -33 \n",
"Q 819 25 488 141 \n",
"L 488 750 \n",
"Q 750 597 1062 519 \n",
"Q 1375 441 1716 441 \n",
"Q 2309 441 2620 675 \n",
"Q 2931 909 2931 1356 \n",
"Q 2931 1769 2642 2001 \n",
"Q 2353 2234 1838 2234 \n",
"L 1294 2234 \n",
"L 1294 2753 \n",
"L 1863 2753 \n",
"Q 2328 2753 2575 2939 \n",
"Q 2822 3125 2822 3475 \n",
"Q 2822 3834 2567 4026 \n",
"Q 2313 4219 1838 4219 \n",
"Q 1578 4219 1281 4162 \n",
"Q 984 4106 628 3988 \n",
"L 628 4550 \n",
"Q 988 4650 1302 4700 \n",
"Q 1616 4750 1894 4750 \n",
"Q 2613 4750 3031 4423 \n",
"Q 3450 4097 3450 3541 \n",
"Q 3450 3153 3228 2886 \n",
"Q 3006 2619 2597 2516 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"99.936\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(20.878125 103.735219) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"59.616\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 5 -->\n",
" <g transform=\"translate(20.878125 63.415219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
"L 3169 4666 \n",
"L 3169 4134 \n",
"L 1269 4134 \n",
"L 1269 2991 \n",
"Q 1406 3038 1543 3061 \n",
"Q 1681 3084 1819 3084 \n",
"Q 2600 3084 3056 2656 \n",
"Q 3513 2228 3513 1497 \n",
"Q 3513 744 3044 326 \n",
"Q 2575 -91 1722 -91 \n",
"Q 1428 -91 1123 -41 \n",
"Q 819 9 494 109 \n",
"L 494 744 \n",
"Q 775 591 1075 516 \n",
"Q 1375 441 1709 441 \n",
"Q 2250 441 2565 725 \n",
"Q 2881 1009 2881 1497 \n",
"Q 2881 1984 2565 2268 \n",
"Q 2250 2553 1709 2553 \n",
"Q 1456 2553 1204 2497 \n",
"Q 953 2441 691 2322 \n",
"L 691 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_7\">\n",
" <g id=\"line2d_12\">\n",
" <g>\n",
" <use xlink:href=\"#mb64e12fe4a\" x=\"34.240625\" y=\"19.296\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- 6 -->\n",
" <g transform=\"translate(20.878125 23.095219) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-36\" d=\"M 2113 2584 \n",
"Q 1688 2584 1439 2293 \n",
"Q 1191 2003 1191 1497 \n",
"Q 1191 994 1439 701 \n",
"Q 1688 409 2113 409 \n",
"Q 2538 409 2786 701 \n",
"Q 3034 994 3034 1497 \n",
"Q 3034 2003 2786 2293 \n",
"Q 2538 2584 2113 2584 \n",
"z\n",
"M 3366 4563 \n",
"L 3366 3988 \n",
"Q 3128 4100 2886 4159 \n",
"Q 2644 4219 2406 4219 \n",
"Q 1781 4219 1451 3797 \n",
"Q 1122 3375 1075 2522 \n",
"Q 1259 2794 1537 2939 \n",
"Q 1816 3084 2150 3084 \n",
"Q 2853 3084 3261 2657 \n",
"Q 3669 2231 3669 1497 \n",
"Q 3669 778 3244 343 \n",
"Q 2819 -91 2113 -91 \n",
"Q 1303 -91 875 529 \n",
"Q 447 1150 447 2328 \n",
"Q 447 3434 972 4092 \n",
"Q 1497 4750 2381 4750 \n",
"Q 2619 4750 2861 4703 \n",
"Q 3103 4656 3366 4563 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-36\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_14\">\n",
" <!-- loss -->\n",
" <g transform=\"translate(14.798438 149.913812) rotate(-90) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_13\">\n",
" <path d=\"M 50.473352 19.296 \n",
"L 53.752688 23.327996 \n",
"L 57.032023 27.359992 \n",
"L 60.311374 31.392008 \n",
"L 63.59071 35.424004 \n",
"L 66.870046 39.456 \n",
"L 70.149381 43.487996 \n",
"L 73.428717 47.519992 \n",
"L 76.708068 51.552008 \n",
"L 79.987403 55.584004 \n",
"L 83.266739 59.616 \n",
"L 86.546074 63.647996 \n",
"L 89.825418 67.679992 \n",
"L 93.104761 71.712008 \n",
"L 96.384097 75.744004 \n",
"L 99.663432 79.776 \n",
"L 102.942768 83.807996 \n",
"L 106.222111 87.839992 \n",
"L 109.501454 91.872008 \n",
"L 112.78079 95.904004 \n",
"L 116.060125 99.936 \n",
"L 119.339469 103.968006 \n",
"L 122.618804 108.000002 \n",
"L 125.89814 112.031998 \n",
"L 129.177483 116.064004 \n",
"L 132.456819 120.096 \n",
"L 135.736162 124.128006 \n",
"L 139.015497 128.160002 \n",
"L 142.294833 132.191998 \n",
"L 145.574176 136.224004 \n",
"L 148.853512 140.256 \n",
"L 152.132855 144.288006 \n",
"L 155.412191 148.320002 \n",
"L 158.691526 152.351998 \n",
"L 161.97087 156.384004 \n",
"L 165.250205 160.416 \n",
"L 168.529548 164.448006 \n",
"L 171.808884 168.480002 \n",
"L 175.08822 172.511998 \n",
"L 178.367563 176.544004 \n",
"L 181.646898 180.576 \n",
"L 184.926234 184.607996 \n",
"L 188.205585 188.640012 \n",
"L 191.484921 192.672008 \n",
"L 194.764256 196.704004 \n",
"L 198.043592 200.736 \n",
"L 201.322927 204.767996 \n",
"L 204.602278 208.800012 \n",
"L 207.881614 212.832008 \n",
"L 211.160949 216.864004 \n",
"L 214.440285 220.896 \n",
"L 217.71962 224.927996 \n",
"L 220.998972 228.960012 \n",
"L 224.278307 232.992008 \n",
"L 227.557643 237.024004 \n",
"L 230.836978 241.056 \n",
"L 234.116314 245.087996 \n",
"L 237.395665 249.120012 \n",
"L 240.675 253.152008 \n",
"L 243.954336 257.184004 \n",
"L 247.233671 261.216 \n",
"L 250.513007 261.216 \n",
"L 253.792358 261.216 \n",
"L 257.071694 261.216 \n",
"L 260.351029 261.216 \n",
"L 263.630365 261.216 \n",
"L 266.9097 261.216 \n",
"L 270.189051 261.216 \n",
"L 273.468387 261.216 \n",
"L 276.747722 261.216 \n",
"L 280.027058 261.216 \n",
"L 283.306394 261.216 \n",
"L 286.585745 261.216 \n",
"L 289.86508 261.216 \n",
"L 293.144416 261.216 \n",
"L 296.423751 261.216 \n",
"L 299.703087 261.216 \n",
"L 302.982438 261.216 \n",
"L 306.261773 261.216 \n",
"L 309.541109 261.216 \n",
"L 312.820445 261.216 \n",
"L 316.099796 261.216 \n",
"L 319.379116 261.216 \n",
"L 322.658467 261.216 \n",
"L 325.937818 261.216 \n",
"L 329.217138 261.216 \n",
"L 332.496489 261.216 \n",
"L 335.775809 261.216 \n",
"L 339.05516 261.216 \n",
"L 342.334511 261.216 \n",
"L 345.613831 261.216 \n",
"L 348.893182 261.216 \n",
"L 352.172502 261.216 \n",
"L 355.451853 261.216 \n",
"L 358.731204 261.216 \n",
"L 362.010524 261.216 \n",
"L 365.289876 261.216 \n",
"L 368.569195 261.216 \n",
"L 371.848547 261.216 \n",
"L 375.127898 261.216 \n",
"\" clip-path=\"url(#p3766d70b85)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 34.240625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 34.240625 7.2 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"p3766d70b85\">\n",
" <rect x=\"34.240625\" y=\"7.2\" width=\"357.12\" height=\"266.112\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_classification_loss(gloss.HingeLoss())"
]
},
{
"cell_type": "markdown",
"id": "f7ad5c13",
"metadata": {},
"source": [
"#### [Logistic Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.LogisticLoss)\n",
"\n",
"The Logistic Loss function computes the performance of binary classification models.\n",
"$$\n",
"L = \\sum_i \\log(1 + \\exp(- {pred}_i \\cdot {label}_i))\n",
"$$\n",
"The log loss decreases the closer the prediction is to the actual label. It is sensitive to outliers, because incorrectly classified points are penalized more."
]
},
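{
"cell_type": "markdown",
"id": "worked-logistic-loss",
"metadata": {},
"source": [
"As a quick check of the formula above, here is the per-example loss for three values of the product $$ {pred}_i \\cdot {label}_i $$. The values `2`, `0` and `-2` are chosen only for illustration:\n",
"\n",
"$$\n",
"\\begin{split}{pred}_i \\cdot {label}_i = 2 &: \\quad \\log(1 + e^{-2}) \\approx 0.127\\\\\n",
"{pred}_i \\cdot {label}_i = 0 &: \\quad \\log(1 + e^{0}) = \\log 2 \\approx 0.693\\\\\n",
"{pred}_i \\cdot {label}_i = -2 &: \\quad \\log(1 + e^{2}) \\approx 2.127\\end{split}\n",
"$$\n",
"\n",
"A confident correct prediction costs little, an undecided prediction costs $$ \\log 2 $$, and a confident wrong prediction costs roughly as much as the size of the error."
]
},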
{
"cell_type": "code",
"execution_count": 15,
"id": "726100e8",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"398.560625pt\" height=\"310.86825pt\" viewBox=\"0 0 398.560625 310.86825\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
" <metadata>\n",
" <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
" <cc:Work>\n",
" <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
" <dc:date>2023-01-05T04:48:36.822569</dc:date>\n",
" <dc:format>image/svg+xml</dc:format>\n",
" <dc:creator>\n",
" <cc:Agent>\n",
" <dc:title>Matplotlib v3.6.2, https://matplotlib.org/</dc:title>\n",
" </cc:Agent>\n",
" </dc:creator>\n",
" </cc:Work>\n",
" </rdf:RDF>\n",
" </metadata>\n",
" <defs>\n",
" <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
" </defs>\n",
" <g id=\"figure_1\">\n",
" <g id=\"patch_1\">\n",
" <path d=\"M 0 310.86825 \n",
"L 398.560625 310.86825 \n",
"L 398.560625 0 \n",
"L 0 0 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"axes_1\">\n",
" <g id=\"patch_2\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"L 34.240625 7.2 \n",
"z\n",
"\" style=\"fill: #ffffff\"/>\n",
" </g>\n",
" <g id=\"matplotlib.axis_1\">\n",
" <g id=\"xtick_1\">\n",
" <g id=\"line2d_1\">\n",
" <defs>\n",
" <path id=\"m9b3bdb791d\" d=\"M 0 0 \n",
"L 0 3.5 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#m9b3bdb791d\" x=\"83.266739\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_1\">\n",
" <!-- −4 -->\n",
" <g transform=\"translate(75.895645 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-2212\" d=\"M 678 2272 \n",
"L 4684 2272 \n",
"L 4684 1741 \n",
"L 678 1741 \n",
"L 678 2272 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
"L 825 1625 \n",
"L 2419 1625 \n",
"L 2419 4116 \n",
"z\n",
"M 2253 4666 \n",
"L 3047 4666 \n",
"L 3047 1625 \n",
"L 3713 1625 \n",
"L 3713 1100 \n",
"L 3047 1100 \n",
"L 3047 0 \n",
"L 2419 0 \n",
"L 2419 1100 \n",
"L 313 1100 \n",
"L 313 1709 \n",
"L 2253 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-34\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_2\">\n",
" <g id=\"line2d_2\">\n",
" <g>\n",
" <use xlink:href=\"#m9b3bdb791d\" x=\"148.853512\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_2\">\n",
" <!-- −2 -->\n",
" <g transform=\"translate(141.482418 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
"L 3431 531 \n",
"L 3431 0 \n",
"L 469 0 \n",
"L 469 531 \n",
"Q 828 903 1448 1529 \n",
"Q 2069 2156 2228 2338 \n",
"Q 2531 2678 2651 2914 \n",
"Q 2772 3150 2772 3378 \n",
"Q 2772 3750 2511 3984 \n",
"Q 2250 4219 1831 4219 \n",
"Q 1534 4219 1204 4116 \n",
"Q 875 4013 500 3803 \n",
"L 500 4441 \n",
"Q 881 4594 1212 4672 \n",
"Q 1544 4750 1819 4750 \n",
"Q 2544 4750 2975 4387 \n",
"Q 3406 4025 3406 3419 \n",
"Q 3406 3131 3298 2873 \n",
"Q 3191 2616 2906 2266 \n",
"Q 2828 2175 2409 1742 \n",
"Q 1991 1309 1228 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-2212\"/>\n",
" <use xlink:href=\"#DejaVuSans-32\" x=\"83.789062\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_3\">\n",
" <g id=\"line2d_3\">\n",
" <g>\n",
" <use xlink:href=\"#m9b3bdb791d\" x=\"214.440285\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_3\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(211.259035 287.910437) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
"Q 1547 4250 1301 3770 \n",
"Q 1056 3291 1056 2328 \n",
"Q 1056 1369 1301 889 \n",
"Q 1547 409 2034 409 \n",
"Q 2525 409 2770 889 \n",
"Q 3016 1369 3016 2328 \n",
"Q 3016 3291 2770 3770 \n",
"Q 2525 4250 2034 4250 \n",
"z\n",
"M 2034 4750 \n",
"Q 2819 4750 3233 4129 \n",
"Q 3647 3509 3647 2328 \n",
"Q 3647 1150 3233 529 \n",
"Q 2819 -91 2034 -91 \n",
"Q 1250 -91 836 529 \n",
"Q 422 1150 422 2328 \n",
"Q 422 3509 836 4129 \n",
"Q 1250 4750 2034 4750 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_4\">\n",
" <g id=\"line2d_4\">\n",
" <g>\n",
" <use xlink:href=\"#m9b3bdb791d\" x=\"280.027058\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_4\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(276.845808 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"xtick_5\">\n",
" <g id=\"line2d_5\">\n",
" <g>\n",
" <use xlink:href=\"#m9b3bdb791d\" x=\"345.613831\" y=\"273.312\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_5\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(342.432581 287.910437) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_6\">\n",
" <!-- x -->\n",
" <g transform=\"translate(209.84125 301.588562) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
"L 2247 1797 \n",
"L 3578 0 \n",
"L 2900 0 \n",
"L 1881 1375 \n",
"L 863 0 \n",
"L 184 0 \n",
"L 1544 1831 \n",
"L 300 3500 \n",
"L 978 3500 \n",
"L 1906 2253 \n",
"L 2834 3500 \n",
"L 3513 3500 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-78\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"matplotlib.axis_2\">\n",
" <g id=\"ytick_1\">\n",
" <g id=\"line2d_6\">\n",
" <defs>\n",
" <path id=\"mbd90fd5020\" d=\"M 0 0 \n",
"L -3.5 0 \n",
"\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </defs>\n",
" <g>\n",
" <use xlink:href=\"#mbd90fd5020\" x=\"34.240625\" y=\"261.575011\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_7\">\n",
" <!-- 0 -->\n",
" <g transform=\"translate(20.878125 265.37423) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-30\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_2\">\n",
" <g id=\"line2d_7\">\n",
" <g>\n",
" <use xlink:href=\"#mbd90fd5020\" x=\"34.240625\" y=\"213.1842\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_8\">\n",
" <!-- 1 -->\n",
" <g transform=\"translate(20.878125 216.983419) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
"L 1825 531 \n",
"L 1825 4091 \n",
"L 703 3866 \n",
"L 703 4441 \n",
"L 1819 4666 \n",
"L 2450 4666 \n",
"L 2450 531 \n",
"L 3481 531 \n",
"L 3481 0 \n",
"L 794 0 \n",
"L 794 531 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-31\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_3\">\n",
" <g id=\"line2d_8\">\n",
" <g>\n",
" <use xlink:href=\"#mbd90fd5020\" x=\"34.240625\" y=\"164.79339\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_9\">\n",
" <!-- 2 -->\n",
" <g transform=\"translate(20.878125 168.592609) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-32\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_4\">\n",
" <g id=\"line2d_9\">\n",
" <g>\n",
" <use xlink:href=\"#mbd90fd5020\" x=\"34.240625\" y=\"116.40258\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_10\">\n",
" <!-- 3 -->\n",
" <g transform=\"translate(20.878125 120.201798) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
"Q 3050 2419 3304 2112 \n",
"Q 3559 1806 3559 1356 \n",
"Q 3559 666 3084 287 \n",
"Q 2609 -91 1734 -91 \n",
"Q 1441 -91 1130 -33 \n",
"Q 819 25 488 141 \n",
"L 488 750 \n",
"Q 750 597 1062 519 \n",
"Q 1375 441 1716 441 \n",
"Q 2309 441 2620 675 \n",
"Q 2931 909 2931 1356 \n",
"Q 2931 1769 2642 2001 \n",
"Q 2353 2234 1838 2234 \n",
"L 1294 2234 \n",
"L 1294 2753 \n",
"L 1863 2753 \n",
"Q 2328 2753 2575 2939 \n",
"Q 2822 3125 2822 3475 \n",
"Q 2822 3834 2567 4026 \n",
"Q 2313 4219 1838 4219 \n",
"Q 1578 4219 1281 4162 \n",
"Q 984 4106 628 3988 \n",
"L 628 4550 \n",
"Q 988 4650 1302 4700 \n",
"Q 1616 4750 1894 4750 \n",
"Q 2613 4750 3031 4423 \n",
"Q 3450 4097 3450 3541 \n",
"Q 3450 3153 3228 2886 \n",
"Q 3006 2619 2597 2516 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-33\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_5\">\n",
" <g id=\"line2d_10\">\n",
" <g>\n",
" <use xlink:href=\"#mbd90fd5020\" x=\"34.240625\" y=\"68.011769\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_11\">\n",
" <!-- 4 -->\n",
" <g transform=\"translate(20.878125 71.810988) scale(0.1 -0.1)\">\n",
" <use xlink:href=\"#DejaVuSans-34\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"ytick_6\">\n",
" <g id=\"line2d_11\">\n",
" <g>\n",
" <use xlink:href=\"#mbd90fd5020\" x=\"34.240625\" y=\"19.620959\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_12\">\n",
" <!-- 5 -->\n",
" <g transform=\"translate(20.878125 23.420177) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
"L 3169 4666 \n",
"L 3169 4134 \n",
"L 1269 4134 \n",
"L 1269 2991 \n",
"Q 1406 3038 1543 3061 \n",
"Q 1681 3084 1819 3084 \n",
"Q 2600 3084 3056 2656 \n",
"Q 3513 2228 3513 1497 \n",
"Q 3513 744 3044 326 \n",
"Q 2575 -91 1722 -91 \n",
"Q 1428 -91 1123 -41 \n",
"Q 819 9 494 109 \n",
"L 494 744 \n",
"Q 775 591 1075 516 \n",
"Q 1375 441 1709 441 \n",
"Q 2250 441 2565 725 \n",
"Q 2881 1009 2881 1497 \n",
"Q 2881 1984 2565 2268 \n",
"Q 2250 2553 1709 2553 \n",
"Q 1456 2553 1204 2497 \n",
"Q 953 2441 691 2322 \n",
"L 691 4666 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-35\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"text_13\">\n",
" <!-- loss -->\n",
" <g transform=\"translate(14.798438 149.913812) rotate(-90) scale(0.1 -0.1)\">\n",
" <defs>\n",
" <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
"L 1178 4863 \n",
"L 1178 0 \n",
"L 603 0 \n",
"L 603 4863 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
"Q 1497 3097 1228 2736 \n",
"Q 959 2375 959 1747 \n",
"Q 959 1119 1226 758 \n",
"Q 1494 397 1959 397 \n",
"Q 2419 397 2687 759 \n",
"Q 2956 1122 2956 1747 \n",
"Q 2956 2369 2687 2733 \n",
"Q 2419 3097 1959 3097 \n",
"z\n",
"M 1959 3584 \n",
"Q 2709 3584 3137 3096 \n",
"Q 3566 2609 3566 1747 \n",
"Q 3566 888 3137 398 \n",
"Q 2709 -91 1959 -91 \n",
"Q 1206 -91 779 398 \n",
"Q 353 888 353 1747 \n",
"Q 353 2609 779 3096 \n",
"Q 1206 3584 1959 3584 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
"L 2834 2853 \n",
"Q 2591 2978 2328 3040 \n",
"Q 2066 3103 1784 3103 \n",
"Q 1356 3103 1142 2972 \n",
"Q 928 2841 928 2578 \n",
"Q 928 2378 1081 2264 \n",
"Q 1234 2150 1697 2047 \n",
"L 1894 2003 \n",
"Q 2506 1872 2764 1633 \n",
"Q 3022 1394 3022 966 \n",
"Q 3022 478 2636 193 \n",
"Q 2250 -91 1575 -91 \n",
"Q 1294 -91 989 -36 \n",
"Q 684 19 347 128 \n",
"L 347 722 \n",
"Q 666 556 975 473 \n",
"Q 1284 391 1588 391 \n",
"Q 1994 391 2212 530 \n",
"Q 2431 669 2431 922 \n",
"Q 2431 1156 2273 1281 \n",
"Q 2116 1406 1581 1522 \n",
"L 1381 1569 \n",
"Q 847 1681 609 1914 \n",
"Q 372 2147 372 2553 \n",
"Q 372 3047 722 3315 \n",
"Q 1072 3584 1716 3584 \n",
"Q 2034 3584 2315 3537 \n",
"Q 2597 3491 2834 3397 \n",
"z\n",
"\" transform=\"scale(0.015625)\"/>\n",
" </defs>\n",
" <use xlink:href=\"#DejaVuSans-6c\"/>\n",
" <use xlink:href=\"#DejaVuSans-6f\" x=\"27.783203\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"88.964844\"/>\n",
" <use xlink:href=\"#DejaVuSans-73\" x=\"141.064453\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <g id=\"line2d_12\">\n",
" <path d=\"M 50.473352 19.296 \n",
"L 53.752688 24.101018 \n",
"L 57.032023 28.902506 \n",
"L 60.311374 33.700072 \n",
"L 63.59071 38.493299 \n",
"L 66.870046 43.28175 \n",
"L 70.149381 48.06494 \n",
"L 73.428717 52.842338 \n",
"L 76.708068 57.613345 \n",
"L 79.987403 62.37729 \n",
"L 83.266739 67.133483 \n",
"L 86.546074 71.881104 \n",
"L 89.825418 76.619321 \n",
"L 93.104761 81.347178 \n",
"L 96.384097 86.063625 \n",
"L 99.663432 90.767531 \n",
"L 102.942768 95.45765 \n",
"L 106.222111 100.132655 \n",
"L 109.501454 104.791057 \n",
"L 112.78079 109.431208 \n",
"L 116.060125 114.0514 \n",
"L 119.339469 118.649716 \n",
"L 122.618804 123.224093 \n",
"L 125.89814 127.772315 \n",
"L 129.177483 132.291959 \n",
"L 132.456819 136.780441 \n",
"L 135.736162 141.235004 \n",
"L 139.015497 145.652624 \n",
"L 142.294833 150.030163 \n",
"L 145.574176 154.36423 \n",
"L 148.853512 158.651237 \n",
"L 152.132855 162.887433 \n",
"L 155.412191 167.068845 \n",
"L 158.691526 171.19133 \n",
"L 161.97087 175.250614 \n",
"L 165.250205 179.242244 \n",
"L 168.529548 183.161702 \n",
"L 171.808884 187.004362 \n",
"L 175.08822 190.765586 \n",
"L 178.367563 194.440731 \n",
"L 181.646898 198.025211 \n",
"L 184.926234 201.514565 \n",
"L 188.205585 204.904509 \n",
"L 191.484921 208.190952 \n",
"L 194.764256 211.370132 \n",
"L 198.043592 214.438636 \n",
"L 201.322927 217.39346 \n",
"L 204.602278 220.232075 \n",
"L 207.881614 222.952431 \n",
"L 211.160949 225.553055 \n",
"L 214.440285 228.033057 \n",
"L 217.71962 230.392131 \n",
"L 220.998972 232.630592 \n",
"L 224.278307 234.749316 \n",
"L 227.557643 236.749789 \n",
"L 230.836978 238.634041 \n",
"L 234.116314 240.404614 \n",
"L 237.395665 242.064515 \n",
"L 240.675 243.617152 \n",
"L 243.954336 245.066299 \n",
"L 247.233671 246.416024 \n",
"L 250.513007 247.670621 \n",
"L 253.792358 248.834561 \n",
"L 257.071694 249.912418 \n",
"L 260.351029 250.908835 \n",
"L 263.630365 251.828459 \n",
"L 266.9097 252.675904 \n",
"L 270.189051 253.455711 \n",
"L 273.468387 254.172302 \n",
"L 276.747722 254.829973 \n",
"L 280.027058 255.432862 \n",
"L 283.306394 255.984927 \n",
"L 286.585745 256.489945 \n",
"L 289.86508 256.951489 \n",
"L 293.144416 257.37294 \n",
"L 296.423751 257.757473 \n",
"L 299.703087 258.108066 \n",
"L 302.982438 258.427501 \n",
"L 306.261773 258.718365 \n",
"L 309.541109 258.983065 \n",
"L 312.820445 259.22383 \n",
"L 316.099796 259.442721 \n",
"L 319.379116 259.641636 \n",
"L 322.658467 259.822332 \n",
"L 325.937818 259.986415 \n",
"L 329.217138 260.135364 \n",
"L 332.496489 260.270536 \n",
"L 335.775809 260.393169 \n",
"L 339.05516 260.504402 \n",
"L 342.334511 260.605271 \n",
"L 345.613831 260.696721 \n",
"L 348.893182 260.779619 \n",
"L 352.172502 260.85475 \n",
"L 355.451853 260.922832 \n",
"L 358.731204 260.984518 \n",
"L 362.010524 261.040402 \n",
"L 365.289876 261.091023 \n",
"L 368.569195 261.136873 \n",
"L 371.848547 261.178397 \n",
"L 375.127898 261.216 \n",
"\" clip-path=\"url(#pb6413782e1)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_3\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 34.240625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_4\">\n",
" <path d=\"M 391.360625 273.312 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_5\">\n",
" <path d=\"M 34.240625 273.312 \n",
"L 391.360625 273.312 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" <g id=\"patch_6\">\n",
" <path d=\"M 34.240625 7.2 \n",
"L 391.360625 7.2 \n",
"\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
" </g>\n",
" </g>\n",
" </g>\n",
" <defs>\n",
" <clipPath id=\"pb6413782e1\">\n",
" <rect x=\"34.240625\" y=\"7.2\" width=\"357.12\" height=\"266.112\"/>\n",
" </clipPath>\n",
" </defs>\n",
"</svg>\n"
],
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_classification_loss(gloss.LogisticLoss())"
]
},
{
"cell_type": "markdown",
"id": "7c73c916",
"metadata": {},
"source": [
"#### [Kullback-Leibler Divergence Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.KLDivLoss)\n",
"\n",
"The Kullback-Leibler divergence loss measures the divergence between two probability distributions by calculating the difference between cross entropy and entropy. It takes as input the probability of predicted label and the probability of true label.\n",
"\n",
"$$\n",
"L = \\sum_i {label}_i * \\big[\\log({label}_i) - {pred}_i\\big]\n",
"$$\n",
"\n",
"The loss is large, if the predicted probability distribution is far from the ground truth probability distribution. KL divergence is an asymmetric measure. KL divergence loss can be used in Variational Autoencoders (VAEs), and reinforcement learning policy networks such as Trust Region Policy Optimization (TRPO)\n",
"\n",
"\n",
"For instance, in the following example we get a KL divergence of 0.02. We set ```from_logits=False```, so the loss functions will apply ```log_softmax``` on the network output, before computing the KL divergence."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "94c457fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output.softmax(): [[0.19999998807907104, 0.5, 0.19999998807907104, 0.09999999403953552]]\n",
"loss (kl divergence): [0.025424206629395485]\n"
]
}
],
"source": [
"output = mx.np.array([[0.39056206, 1.3068528, 0.39056206, -0.30258512]])\n",
"print('output.softmax(): {}'.format(npx.softmax(output).asnumpy().tolist()))\n",
"target_dist = mx.np.array([[0.3, 0.4, 0.1, 0.2]])\n",
"loss_fn = gloss.KLDivLoss(from_logits=False)\n",
"loss = loss_fn(output, target_dist)\n",
"print('loss (kl divergence): {}'.format(loss.asnumpy().tolist()))"
]
},
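{
"cell_type": "markdown",
"id": "worked-kl-divergence",
"metadata": {},
"source": [
"The printed value can be reproduced by hand from the softmax output and the target distribution shown above. This is only a check of the arithmetic. Reading the final division as an average over the four classes is an inference from the numbers, not a statement taken from the API reference:\n",
"\n",
"$$\n",
"\\begin{split}\\sum_i {label}_i \\log\\frac{{label}_i}{{softmax(output)}_i} &= 0.3 \\log\\frac{0.3}{0.2} + 0.4 \\log\\frac{0.4}{0.5} + 0.1 \\log\\frac{0.1}{0.2} + 0.2 \\log\\frac{0.2}{0.1}\\\\\n",
"&\\approx 0.12164 - 0.08926 - 0.06931 + 0.13863 \\approx 0.1017\\\\\n",
"\\frac{0.1017}{4} &\\approx 0.0254\\end{split}\n",
"$$\n",
"\n",
"The sum of about `0.1017` is the KL divergence between the two distributions. Dividing it by the number of classes gives the `0.0254` reported by the loss function."
]
},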
{
"cell_type": "markdown",
"id": "1bb69431",
"metadata": {},
"source": [
"#### [Triplet Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.TripletLoss)\n",
"\n",
"Triplet loss takes three input arrays and measures the relative similarity. It takes a positive and negative input and the anchor.\n",
"\n",
"$$\n",
"L = \\sum_i \\max(\\Vert {pos_i}_i - {pred} \\Vert_2^2 -\n",
" \\Vert {neg_i}_i - {pred} \\Vert_2^2 + {margin}, 0)\n",
"$$\n",
"\n",
"The loss function minimizes the distance between similar inputs and maximizes the distance between dissimilar ones.\n",
"In the case of learning embeddings for images of characters, the network may get as input the following 3 images:\n",
"\n",
"![triplet_loss](/_static/triplet_loss.png)\n",
"\n",
"The network would learn to minimize the distance between the two `A`'s and maximize the distance between `A` and `Z`.\n",
"\n",
"#### [CTC Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.CTCLoss)\n",
"\n",
"CTC Loss is the [connectionist temporal classification loss](https://distill.pub/2017/ctc/) . It is used to train recurrent neural networks with variable time dimension. It learns the alignment and labelling of input sequences. It takes a sequence as input and gives probabilities for each timestep. For instance, in the following image the word is not well aligned with the 5 timesteps because of the different sizes of characters. CTC Loss finds for each timestep the highest probability e.g. `t1` presents with high probability a `C`. It combines the highest probapilities and returns the best path decoding.\n",
"\n",
"![ctc_loss](/_static/ctc_loss.png)\n",
"\n",
"#### [Cosine Embedding Loss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.CosineEmbeddingLoss)\n",
"The cosine embedding loss computes the cosine distance between two input vectors.\n",
"\n",
"$$\n",
"\\begin{split}L = \\sum_i \\begin{cases} 1 - {cos\\_sim({input1}_i, {input2}_i)} & \\text{ if } {label}_i = 1\\\\\n",
" {cos\\_sim({input1}_i, {input2}_i)} & \\text{ if } {label}_i = -1 \\end{cases}\\\\\n",
"cos\\_sim(input1, input2) = \\frac{{input1}_i.{input2}_i}{||{input1}_i||.||{input2}_i||}\\end{split}\n",
"$$\n",
"\n",
"Cosine distance measures the similarity between two arrays given a label and is typically used for learning nonlinear embeddings.\n",
"For instance, in the following code example we measure the similarity between the input vectors `x` and `y`. Since they are the same the label equals `1`. The loss function returns $$ \\sum_i 1 - {cos\\_sim({input1}_i, {input2}_i)} $$ which is equal `0`."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a1d77c4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.]\n"
]
}
],
"source": [
"x = mx.np.array([1,0,1,0,1,0])\n",
"y = mx.np.array([1,0,1,0,1,0])\n",
"label = mx.np.array([1])\n",
"loss = gloss.CosineEmbeddingLoss()\n",
"print(loss(x,y,label))"
]
},
{
"cell_type": "markdown",
"id": "9aa456f4",
"metadata": {},
"source": [
"Now let's make `y` the opposite of `x`, so we set the label `-1` and the function will return $$ \\sum_i cos\\_sim(input1, input2) $$"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "05f2ce95",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.]\n"
]
}
],
"source": [
"x = mx.np.array([1,0,1,0,1,0])\n",
"y = mx.np.array([0,1,0,1,0,1])\n",
"label = mx.np.array([-1])\n",
"loss = gloss.CosineEmbeddingLoss()\n",
"print(loss(x,y,label))"
]
},
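{
"cell_type": "markdown",
"id": "worked-cosine-similarity",
"metadata": {},
"source": [
"Both results above can be checked by hand with the cosine similarity definition. For the identical vectors of the first example and the second pair of vectors we get:\n",
"\n",
"$$\n",
"\\begin{split}x = y = (1,0,1,0,1,0) &: \\quad cos\\_sim(x, y) = \\frac{3}{\\sqrt{3} \\cdot \\sqrt{3}} = 1, \\quad L = 1 - 1 = 0\\\\\n",
"x = (1,0,1,0,1,0),\\ y = (0,1,0,1,0,1) &: \\quad cos\\_sim(x, y) = \\frac{0}{\\sqrt{3} \\cdot \\sqrt{3}} = 0, \\quad L = 0\\end{split}\n",
"$$\n",
"\n",
"In the first case the label is `1` and the loss is one minus the similarity. In the second case the label is `-1` and the loss is the similarity itself."
]
},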
{
"cell_type": "markdown",
"id": "0345267e",
"metadata": {},
"source": [
"#### [PoissonNLLLoss](../../../../api/gluon/loss/index.rst#mxnet.gluon.loss.PoissonNLLLoss)\n",
"Poisson distribution is widely used for modelling count data. It is defined as:\n",
"\n",
"$$\n",
"f(x) = \\frac{\\mu ^ {\\kern 0.08 em x} e ^ {-\\mu}} {x!} \\qquad \\qquad x = 0,1,2 , \\ldots \\,.\n",
"$$\n",
"\n",
"\n",
"For instance, the count of cars in road traffic approximately follows a Poisson distribution. Using an ordinary least squares model for Poisson distributed data would not work well because of two reasons:\n",
" - count data cannot be negative\n",
" - variance may not be constant\n",
"\n",
"Instead we can use a Poisson regression model, also known as log-linear model. Thereby the Poisson incident rate $$\\mu$$ is\n",
"modelled by a linear combination of unknown parameters.\n",
"We can then use the PoissonNLLLoss which calculates the negative log likelihood for a target that follows a Poisson distribution.\n",
"\n",
"$$ L = \\text{pred} - \\text{target} * \\log(\\text{pred}) +\\log(\\text{target!}) $$\n",
"\n",
"## Advanced: Weighted Loss\n",
"\n",
"Some examples in a batch may be more important than others. We can apply weights to individual examples during the forward pass of the loss function using the `sample_weight` argument. All examples are weighted equally by default."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "1dbf2b6b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.5, 1. ])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.ones((2,))\n",
"y = np.ones((2,)) * 2\n",
"loss = gloss.L2Loss()\n",
"loss(x, y, np.array([1, 2]))"
]
},
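{
"cell_type": "markdown",
"id": "worked-weighted-l2",
"metadata": {},
"source": [
"The output above can be verified by hand, assuming the L2 loss of a single example is half the squared error. Both examples have a prediction of `1` and a label of `2`, so the unweighted loss is the same for each, and the sample weights `1` and `2` then scale it:\n",
"\n",
"$$\n",
"\\begin{split}\\frac{1}{2}(1 - 2)^2 &= 0.5\\\\\n",
"1 \\cdot 0.5 &= 0.5\\\\\n",
"2 \\cdot 0.5 &= 1.0\\end{split}\n",
"$$\n",
"\n",
"The second example therefore contributes twice as much to the total loss, and to the gradients, as the first."
]
},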
{
"cell_type": "markdown",
"id": "e402d55f",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"In this tutorial we saw an example of how to evaluate model performance using loss functions (during the forward pass). Crucially, we then saw how calculate parameter gradients (using `backward`) which would minimise this loss. You should now have a better understanding of when to apply different loss functions, especially for regression vs classification tasks.\n",
"\n",
"## Recommended Next Steps\n",
"\n",
"In addition to loss functions, which are used for explicit optimization, you might want to look at metrics that give useful evaluation feedback even if they're not explicitly optimized for in the same way as the loss. You might also want to learn more about the mechanics of the backpropagation stage in the autograd tutorial."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}