{
"cells": [
{
"cell_type": "markdown",
"id": "e98ccb6a",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Kullback-Leibler (KL) Divergence\n",
"\n",
"Kullback-Leibler (KL) Divergence is a measure of how one probability distribution is different from a second, reference probability distribution. Smaller KL Divergence values indicate more similar distributions and, since this loss function is differentiable, we can use gradient descent to minimize the KL divergence between network outputs and some target distribution. As an example, this can be used in Variational Autoencoders (VAEs), and reinforcement learning policy networks such as [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477).\n",
"\n",
"In MXNet Gluon, we can use [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) to compare categorical distributions. One important thing to note is that the KL Divergence is an asymmetric measure (i.e. `KL(P,Q) != KL(Q,P)`): order matters and we should compare our predicted distribution with our target distribution in that order. Another thing to note is that there are two ways to use [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) that depend on how we set `from_logits` (which has a default value of true). \n",
"\n",
"As an example, let's compare a few categorical distributions (`dist_1`, `dist_2` and `dist_3`), each with 4 categories."
]
},
{
"cell_type": "markdown",
"id": "d14d701b",
"metadata": {},
"source": [
"```\n",
"from matplotlib import pyplot as plt\n",
"import mxnet as mx\n",
"import numpy as np\n",
"\n",
"idx = np.array([1, 2, 3, 4])\n",
"dist_1 = np.array([0.2, 0.5, 0.2, 0.1])\n",
"dist_2 = np.array([0.3, 0.4, 0.1, 0.2])\n",
"dist_3 = np.array([0.1, 0.1, 0.1, 0.7])\n",
"\n",
"plt.figure(figsize=(10,5))\n",
"plt.subplot(1,2,1)\n",
"plt.ylim(top=1)\n",
"plt.bar(idx, dist_1, alpha=0.5, color='black')\n",
"plt.bar(idx, dist_2, alpha=0.5, color='aqua')\n",
"plt.title('Distributions 1 & 2')\n",
"plt.subplot(1,2,2)\n",
"plt.ylim(top=1)\n",
"plt.bar(idx, dist_1, alpha=0.5, color='black')\n",
"plt.bar(idx, dist_3, alpha=0.5, color='aqua')\n",
"plt.title('Distributions 1 & 3')\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "2b7574ab",
"metadata": {},
"source": [
"We can see visually that distributions 1 and 2 are more similar than distributions 1 and 3. We'll confirm this result using [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html). When using [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) with the default `from_logits=True` we need:\n",
"\n",
"1. our predictions to be parameters of a logged probability distribution.\n",
"2. our targets to be parameters of a probability distribution (i.e. not logged).\n",
"\n",
"We often apply a [softmax](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.softmax.html) operation to the output of our network to get a distribution, but this can have a numerically unstable gradient calculation. As as stable alternative, we use [log_softmax](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.log_softmax.html) and so this is what is expected by [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) when `from_logits=True`. We also usually work with batches of predictions, so the predictions and targets need to have a batch dimension (the first axis by default).\n",
"\n",
"Since we're already working with distributions in this example, we don't need to apply the softmax and only need to apply [log](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.log.html). And we'll create batch dimensions even though we're working with single distributions."
]
},
{
"cell_type": "markdown",
"id": "6e450009",
"metadata": {},
"source": [
"```\n",
"def kl_divergence(dist_a, dist_b):\n",
" # add batch dimension\n",
" pred_batch = mx.nd.array(dist_a).expand_dims(0)\n",
" target_batch = mx.nd.array(dist_b).expand_dims(0)\n",
" # log the distribution\n",
" pred_batch = pred_batch.log()\n",
" # create loss (assuming we have a logged prediction distribution)\n",
" loss_fn = mx.gluon.loss.KLDivLoss(from_logits=True)\n",
" divergence = loss_fn(pred_batch, target_batch)\n",
" return divergence.asscalar()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "5f4f71b1",
"metadata": {},
"source": [
"```\n",
"print(\"Distribution 1 compared with Distribution 2: {}\".format(\n",
" kl_divergence(dist_1, dist_2)))\n",
"print(\"Distribution 1 compared with Distribution 3: {}\".format(\n",
" kl_divergence(dist_1, dist_3)))\n",
"print(\"Distribution 1 compared with Distribution 1: {}\".format(\n",
" kl_divergence(dist_1, dist_1)))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "70e07373",
"metadata": {},
"source": [
"As expected we see a smaller KL Divergence for distributions 1 & 2 than 1 & 3. And we also see the KL Divergence of a distribution with itself is 0.\n",
"\n",
"#### `from_logits=False`\n",
"\n",
"Alternatively, instead of manually applying the [log_softmax](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.log_softmax.html) to our network outputs, we can leave that to the loss function. When setting `from_logits=False` on [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html), the [log_softmax](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.log_softmax.html) is applied to the first argument passed to `loss_fn`. As an example, let's assume our network outputs us the values below (favorably chosen so that when we [softmax](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.softmax.html) these values we get the same distribution parameters as `dist_1`)."
]
},
{
"cell_type": "markdown",
"id": "a2738a99",
"metadata": {},
"source": [
"```\n",
"output = mx.nd.array([0.39056206, 1.3068528, 0.39056206, -0.30258512])\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "8778f408",
"metadata": {},
"source": [
"We can pass this to our [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) loss function (with `from_logits=False`) and get the same KL Divergence between `dist_1` and `dist_2` as before, because the [log_softmax](/api/python/docs/api/ndarray/_autogen/mxnet.ndarray.log_softmax.html) is applied within the loss function."
]
},
{
"cell_type": "markdown",
"id": "1b7fecd5",
"metadata": {},
"source": [
"```\n",
"def kl_divergence_not_from_logits(dist_a, dist_b):\n",
" # add batch dimension\n",
" pred_batch = mx.nd.array(dist_a).expand_dims(0)\n",
" target_batch = mx.nd.array(dist_b).expand_dims(0)\n",
" # create loss (assuming we have a logged prediction distribution)\n",
" loss_fn = mx.gluon.loss.KLDivLoss(from_logits=False)\n",
" divergence = loss_fn(pred_batch, target_batch)\n",
" return divergence.asscalar()\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "2a415e2f",
"metadata": {},
"source": [
"```\n",
"print(\"Distribution 1 compared with Distribution 2: {}\".format(\n",
" kl_divergence_not_from_logits(output, dist_2)))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "092e2596",
"metadata": {},
"source": [
"### Advanced: Common Support\n",
"\n",
"Occasionally, you might have issues with [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html). One common issue arises when the support of the distributions being compared are not the same. 'Support' here is referring to the values of the distribution which have a non-zero probability. Conveniently, all our examples above had the same support, but we might have a case where some categories have a probability of 0."
]
},
{
"cell_type": "markdown",
"id": "4b0f83ad",
"metadata": {},
"source": [
"```\n",
"dist_4 = np.array([0, 0.9, 0, 0.1])\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "4ef730ba",
"metadata": {},
"source": [
"```\n",
"print(\"Distribution 4 compared with Distribution 1: {}\".format(\n",
" kl_divergence(dist_4, dist_1)))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "215bc4e2",
"metadata": {},
"source": [
"We can see that the result is `nan`, which will obviously cause issues when calculating the gradient. One option is to add a small value `epsilon` to all of the probabilities, and this is already done for the target distribution (using the value of 1e-12).\n",
"\n",
"### Advanced: Aggregation\n",
"\n",
"One minor difference between the true definition of KL Divergence and the result from [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) is how the aggregation of category contributions is performed. Although the true definition sums up these contributions, the default behaviour in MXNet Gluon is to average terms along the batch dimension. As a result, the [KLDivLoss](/api/python/docs/api/gluon/_autogen/mxnet.gluon.loss.KLDivLoss.html) output will be smaller than the true definition by a factor of the number of categories."
]
},
{
"cell_type": "markdown",
"id": "44e51392",
"metadata": {},
"source": [
"```\n",
"true_divergence = (dist_2*(np.log(dist_2)-np.log(dist_1))).sum()\n",
"print('true_divergence: {}'.format(true_divergence))\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "909b93f1",
"metadata": {},
"source": [
"```\n",
"num_categories = dist_1.shape[0]\n",
"divergence = kl_divergence(dist_1, dist_2)\n",
"print('divergence: {}'.format(divergence))\n",
"print('divergence * num_categories: {}'.format(divergence * num_categories))\n",
"```"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}