{"nbformat": 4, "cells": [{"source": "# Automatic differentiation\n\nMXNet supports automatic differentiation with the `autograd` package.\n`autograd` allows you to differentiate a graph of NDArray operations\nwith the chain rule.\nThis is called define-by-run, i.e., the network is defined on-the-fly by\nrunning forward computation. You can define exotic network structures\nand differentiate them, and each iteration can have a totally different\nnetwork structure.", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\nfrom mxnet import autograd", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "To use `autograd`, we must first mark the variables that require gradients and\nattach gradient buffers to them:", "cell_type": "markdown", "metadata": {}}, {"source": "x = mx.nd.array([[1, 2], [3, 4]])\nx.attach_grad()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Now we can define the network while running forward computation by wrapping\nit inside a `record` scope (operations outside of `record` do not define\na graph and cannot be differentiated):", "cell_type": "markdown", "metadata": {}}, {"source": "with autograd.record():\n    y = x * 2\n    z = y * x", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Let's backprop with `z.backward()`, which is equivalent to\n`z.backward(mx.nd.ones_like(z))`. When z has more than one entry, `z.backward()`\nis equivalent to `mx.nd.sum(z).backward()`:", "cell_type": "markdown", "metadata": {}}, {"source": "z.backward()\nprint(x.grad)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\nNow, let's see if this is the expected output.\n\nHere, y = g(x) and z = f(x, y) = f(x, g(x)),\nwhich means y = 2 * x and z = y * x = 2 * x * x.\n\nAfter doing backprop with `z.backward()`, we get the gradient dz/dx as follows:\n\ndy/dx = 2,\ndz/dx = 4 * x\n\nSo, we should get `x.grad` as the array [[4, 8], [12, 16]].\n\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}