{"nbformat": 4, "cells": [{"source": "# RowSparseNDArray - NDArray for Sparse Gradient Updates\n\n## Motivation\n\nMany real-world datasets deal with high-dimensional sparse feature vectors. When learning\nthe weights of models with sparse datasets, the gradients of the weights can also be sparse.\n\nLet's say we perform a matrix multiplication of ``X`` and ``W``, where ``X`` is a 1x2 matrix, and ``W`` is a 2x3 matrix. Let ``Y`` be the product of the two matrices:", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\nX = mx.nd.array([[1,0]])\nW = mx.nd.array([[3,4,5], [6,7,8]])\nY = mx.nd.dot(X, W)\n{'X': X, 'W': W, 'Y': Y}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "As you can see,\n\n```\nY[0][0] = X[0][0] * W[0][0] + X[0][1] * W[1][0] = 1 * 3 + 0 * 6 = 3\nY[0][1] = X[0][0] * W[0][1] + X[0][1] * W[1][1] = 1 * 4 + 0 * 7 = 4\nY[0][2] = X[0][0] * W[0][2] + X[0][1] * W[1][2] = 1 * 5 + 0 * 8 = 5\n```\n\nWhat about dY / dW, the gradient for ``W``? Let's call it ``grad_W``, and assume that the gradient flowing back into ``Y`` is all ones (as it is when the entries of ``Y`` are summed into a scalar loss). Since we are taking derivatives with respect to ``W``, ``grad_W`` has the same shape as ``W``, which is 2x3. Each entry of ``grad_W`` is then calculated as follows:\n\n```\ngrad_W[0][0] = X[0][0] = 1\ngrad_W[0][1] = X[0][0] = 1\ngrad_W[0][2] = X[0][0] = 1\ngrad_W[1][0] = X[0][1] = 0\ngrad_W[1][1] = X[0][1] = 0\ngrad_W[1][2] = X[0][1] = 0\n```\n\nIn fact, you can calculate ``grad_W`` by multiplying the transpose of ``X`` with a matrix of ones:", "cell_type": "markdown", "metadata": {}}, {"source": "grad_W = mx.nd.dot(X, mx.nd.ones_like(Y), transpose_a=True)\ngrad_W", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "As you can see, row 0 of ``grad_W`` contains non-zero values while row 1 of ``grad_W`` does not. 
Why did that happen?\nIf you look at how ``grad_W`` is calculated, notice that since column 1 of ``X`` is filled with zeros, row 1 of ``grad_W`` is filled with zeros too.\n\nIn the real world, parameters that interact with sparse inputs usually have gradients in which many row slices are all zeros. Storing and manipulating such matrices in the default dense structure wastes memory and computation on the zeros. More importantly, many gradient-based optimization methods such as SGD, [AdaGrad](https://stanford.edu/~jduchi/projects/DuchiHaSi10_colt.pdf) and [Adam](https://arxiv.org/pdf/1412.6980.pdf)\ncan take advantage of sparse gradients and remain efficient and effective.\n**In MXNet, the ``RowSparseNDArray`` stores the matrix in ``row sparse`` format, which is designed for arrays in which most row slices are all zeros.**\nIn this tutorial, we will describe what the row sparse format is and how to use RowSparseNDArray for sparse gradient updates in MXNet.\n\n## Prerequisites\n\nTo complete this tutorial, you need:\n\n- MXNet. See the instructions for your operating system in [Setup and Installation](https://mxnet.io/get_started/install.html)\n- [Jupyter](http://jupyter.org/)\n ```\n pip install jupyter\n ```\n- Basic knowledge of NDArray in MXNet. See the detailed tutorial for NDArray in [NDArray - Imperative tensor operations on CPU/GPU](https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html)\n- Understanding of [automatic differentiation with autograd](http://gluon.mxnet.io/chapter01_crashcourse/autograd.html)\n- GPUs - A section of this tutorial uses GPUs. 
If you don't have GPUs on your\nmachine, simply set the variable `gpu_device` (set in the GPUs section of this\ntutorial) to `mx.cpu()`.\n\n## Row Sparse Format\n\nA RowSparseNDArray represents a multidimensional NDArray using two separate arrays:\n`data` and `indices`.\n\n- data: an NDArray of any dtype with shape `[D0, D1, ..., Dn]`.\n- indices: a 1D int64 NDArray with shape `[D0]` with values sorted in ascending order.\n\nThe ``indices`` array stores the indices of the row slices with non-zeros,\nwhile the values are stored in the ``data`` array. The corresponding NDArray `dense` represented by RowSparseNDArray `rsp` has\n\n``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]``\n\nA RowSparseNDArray is typically used to represent the non-zero row slices of a large NDArray of shape `[LARGE0, D1, ..., Dn]`, where LARGE0 >> D0 and most row slices are zeros.\n\nGiven this two-dimensional matrix:", "cell_type": "markdown", "metadata": {}}, {"source": "[[ 1, 2, 3],\n [ 0, 0, 0],\n [ 4, 0, 5],\n [ 0, 0, 0],\n [ 0, 0, 0]]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The row sparse representation would be:\n\n- the `data` array holds all the non-zero row slices of the array.\n- the `indices` array stores the row index of each row slice with non-zero elements.", "cell_type": "markdown", "metadata": {}}, {"source": "data = [[1, 2, 3], [4, 0, 5]]\nindices = [0, 2]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "`RowSparseNDArray` supports multidimensional arrays. 
Given this 3D tensor:", "cell_type": "markdown", "metadata": {}}, {"source": "[[[1, 0],\n [0, 2],\n [3, 4]],\n\n [[5, 0],\n [6, 0],\n [0, 0]],\n\n [[0, 0],\n [0, 0],\n [0, 0]]]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The row sparse representation would be (with `data` and `indices` defined the same as above):", "cell_type": "markdown", "metadata": {}}, {"source": "data = [[[1, 0], [0, 2], [3, 4]], [[5, 0], [6, 0], [0, 0]]]\nindices = [0, 1]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "``RowSparseNDArray`` is a subclass of ``NDArray``. If you query the **stype** of a RowSparseNDArray,\nthe value will be **\"row_sparse\"**.\n\n## Array Creation\n\nYou can create a `RowSparseNDArray` with data and indices by using the `row_sparse_array` function:", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\nimport numpy as np\n# Create a RowSparseNDArray with python lists\nshape = (6, 2)\ndata_list = [[1, 2], [3, 4]]\nindices_list = [1, 4]\na = mx.nd.sparse.row_sparse_array((data_list, indices_list), shape=shape)\n# Create a RowSparseNDArray with numpy arrays\ndata_np = np.array([[1, 2], [3, 4]])\nindices_np = np.array([1, 4])\nb = mx.nd.sparse.row_sparse_array((data_np, indices_np), shape=shape)\n{'a':a, 'b':b}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Function Overview\n\nSimilar to `CSRNDArray`, there are several functions of `RowSparseNDArray` that behave the same way. 
In the code blocks below, you can try out these common functions:\n\n- **.dtype** - to check the data type\n- **.asnumpy** - to convert to a numpy array for inspection\n- **.data** - to access the data array\n- **.indices** - to access the indices array\n- **.tostype** - to convert to another storage type\n- **cast_storage** - to convert the storage type with the `mx.nd.sparse.cast_storage` operator\n- **.copy** - to copy the array\n- **.copyto** - to deep copy into an existing array\n\n## Setting Type\n\nYou can create a `RowSparseNDArray` from another array, specifying the element data type with the option `dtype`, which accepts a numpy type. By default, `float32` is used.", "cell_type": "markdown", "metadata": {}}, {"source": "# Float32 is used by default\nc = mx.nd.sparse.array(a)\n# Create a 16-bit float array\nd = mx.nd.array(a, dtype=np.float16)\n(c.dtype, d.dtype)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Inspecting Arrays\n\nAs with `CSRNDArray`, you can inspect the contents of a `RowSparseNDArray` by filling\nits contents into a dense `numpy.ndarray` using the `asnumpy` method.", "cell_type": "markdown", "metadata": {}}, {"source": "a.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can inspect the internal storage of a RowSparseNDArray by accessing attributes such as `indices` and `data`:", "cell_type": "markdown", "metadata": {}}, {"source": "# Access data array\ndata = a.data\n# Access indices array\nindices = a.indices\n{'a.stype': a.stype, 'data':data, 'indices':indices}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Storage Type Conversion\n\nYou can convert an NDArray to a RowSparseNDArray and vice versa by using the `tostype` method:", "cell_type": "markdown", "metadata": {}}, {"source": "# Create a dense NDArray\nones = mx.nd.ones((2,2))\n# Cast the storage type from `default` to `row_sparse`\nrsp = ones.tostype('row_sparse')\n# Cast the storage 
type from `row_sparse` to `default`\ndense = rsp.tostype('default')\n{'rsp':rsp, 'dense':dense}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can also convert the storage type by using the `cast_storage` operator:", "cell_type": "markdown", "metadata": {}}, {"source": "# Create a dense NDArray\nones = mx.nd.ones((2,2))\n# Cast the storage type to `row_sparse`\nrsp = mx.nd.sparse.cast_storage(ones, 'row_sparse')\n# Cast the storage type to `default`\ndense = mx.nd.sparse.cast_storage(rsp, 'default')\n{'rsp':rsp, 'dense':dense}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Copies\n\nYou can use the `copy` method, which makes a deep copy of the array and its data and returns a new array.\nYou can also use the `copyto` method or the slice operator `[]` to deep copy into an existing array.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,2)).tostype('row_sparse')\nb = a.copy()\nc = mx.nd.sparse.zeros('row_sparse', (2,2))\nc[:] = a\nd = mx.nd.sparse.zeros('row_sparse', (2,2))\na.copyto(d)\n{'b is a': b is a, 'b.asnumpy()':b.asnumpy(), 'c.asnumpy()':c.asnumpy(), 'd.asnumpy()':d.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "If the storage types of the source array and the destination array do not match,\nthe storage type of the destination array will not change when copying with `copyto` or the slice operator `[]`. 
The source array will be temporarily converted to the storage type of the destination array before the copy.", "cell_type": "markdown", "metadata": {}}, {"source": "e = mx.nd.sparse.zeros('row_sparse', (2,2))\nf = mx.nd.sparse.zeros('row_sparse', (2,2))\ng = mx.nd.ones(e.shape)\ne[:] = g\ng.copyto(f)\n{'e.stype':e.stype, 'f.stype':f.stype, 'g.stype':g.stype}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Retain Row Slices\n\nYou can retain a subset of the row slices of a RowSparseNDArray, specified by their row indices, with the `retain` operator.", "cell_type": "markdown", "metadata": {}}, {"source": "data = [[1, 2], [3, 4], [5, 6]]\nindices = [0, 2, 3]\nrsp = mx.nd.sparse.row_sparse_array((data, indices), shape=(5, 2))\n# Retain row 0 and row 1\nrsp_retained = mx.nd.sparse.retain(rsp, mx.nd.array([0, 1]))\n{'rsp.asnumpy()': rsp.asnumpy(), 'rsp_retained': rsp_retained, 'rsp_retained.asnumpy()': rsp_retained.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Sparse Operators and Storage Type Inference\n\nOperators that have specialized implementations for sparse arrays can be accessed in ``mx.nd.sparse``. 
You can read the [mxnet.ndarray.sparse API documentation](http://mxnet.io/versions/master/api/python/ndarray/sparse.html) to find out which sparse operators are available.", "cell_type": "markdown", "metadata": {}}, {"source": "shape = (3, 5)\ndata = [7, 8, 9]\nindptr = [0, 2, 2, 3]\nindices = [0, 2, 1]\n# A csr matrix as lhs\nlhs = mx.nd.sparse.csr_matrix((data, indices, indptr), shape=shape)\n# A dense matrix as rhs\nrhs = mx.nd.ones((3, 2))\n# row_sparse result is inferred from sparse operator dot(csr.T, dense) based on input stypes\ntranspose_dot = mx.nd.sparse.dot(lhs, rhs, transpose_a=True)\n{'transpose_dot': transpose_dot, 'transpose_dot.asnumpy()': transpose_dot.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "For any sparse operator, the storage type of the output array is inferred from the storage types of the inputs. You can either read the documentation or inspect the `stype` attribute of the output array to see which storage type was inferred:", "cell_type": "markdown", "metadata": {}}, {"source": "a = transpose_dot.copy()\nb = a * 2 # b will be a RowSparseNDArray since zero multiplied by 2 is still zero\nc = a + mx.nd.ones((5, 2)) # c will be a dense NDArray\n{'b.stype':b.stype, 'c.stype':c.stype}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "For operators that don't specialize in sparse arrays, you can still use them with sparse inputs, at some cost in performance.\nIn MXNet, dense operators require all inputs and outputs to be in the dense format.\n\nIf sparse inputs are provided, MXNet will temporarily convert them into dense ones so that the dense operator can be used.\n\nIf sparse outputs are provided, MXNet will convert the dense outputs generated by the dense operator into the provided sparse format.", "cell_type": "markdown", "metadata": {}}, 
{"source": "e = mx.nd.sparse.zeros('row_sparse', a.shape)\nd = mx.nd.log(a) # dense operator with a sparse input\ne = mx.nd.log(a, out=e) # dense operator with a sparse output\n{'a.stype':a.stype, 'd.stype':d.stype, 'e.stype':e.stype} # the stypes of a and e are not changed", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Note that warning messages will be printed when such a storage fallback event happens. If you are using Jupyter Notebook, the warning messages will be printed in your terminal console.\n\n## Sparse Optimizers\n\nIn MXNet, sparse gradient updates are applied when the weight, state and gradient are all in `row_sparse` storage.\nThe sparse optimizers only update the row slices of the weight and the states whose indices appear\nin `grad.indices`. For example, the default update rule for the SGD optimizer is:\n\n```\nrescaled_grad = learning_rate * (rescale_grad * clip(grad, clip_gradient) + weight_decay * weight)\nstate = momentum * state + rescaled_grad\nweight = weight - state\n```\n\nMeanwhile, the sparse update rule for the SGD optimizer is:\n\n```\nfor row in grad.indices:\n    rescaled_grad[row] = learning_rate * (rescale_grad * clip(grad[row], clip_gradient) + weight_decay * weight[row])\n    state[row] = momentum * state[row] + rescaled_grad[row]\n    weight[row] = weight[row] - state[row]\n```", "cell_type": "markdown", "metadata": {}}, {"source": "# Create weight\nshape = (4, 2)\nweight = mx.nd.ones(shape).tostype('row_sparse')\n# Create gradient\ndata = [[1, 2], [4, 5]]\nindices = [1, 2]\ngrad = mx.nd.sparse.row_sparse_array((data, indices), shape=shape)\nsgd = mx.optimizer.SGD(learning_rate=0.01, momentum=0.01)\n# Create momentum\nmomentum = sgd.create_state(0, weight)\n# Before the update\n{\"grad.asnumpy()\":grad.asnumpy(), \"weight.asnumpy()\":weight.asnumpy(), \"momentum.asnumpy()\":momentum.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "sgd.update(0, 
weight, grad, momentum)\n# Only rows 1 and 2 (the rows listed in grad.indices) are updated for both weight and momentum\n{\"weight.asnumpy()\":weight.asnumpy(), \"momentum.asnumpy()\":momentum.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Note that both [mxnet.optimizer.SGD](https://mxnet.incubator.apache.org/api/python/optimization.html#mxnet.optimizer.SGD)\nand [mxnet.optimizer.Adam](https://mxnet.incubator.apache.org/api/python/optimization.html#mxnet.optimizer.Adam) support sparse updates in MXNet.\n\n## Advanced Topics\n\n### GPU Support\n\nBy default, RowSparseNDArray operators are executed on CPU. In MXNet, GPU support for RowSparseNDArray is experimental,\nwith only a few sparse operators such as `cast_storage` and `dot`.\n\nTo create a RowSparseNDArray on a GPU, we need to explicitly specify the context:\n\n**Note:** If a GPU is not available, an error will be reported in the following section. To execute it on a CPU instead, set `gpu_device` to `mx.cpu()`.", "cell_type": "markdown", "metadata": {}}, {"source": "import sys\ngpu_device = mx.gpu() # Change this to mx.cpu() if no GPU is available.\ntry:\n    a = mx.nd.sparse.zeros('row_sparse', (100, 100), ctx=gpu_device)\n    print(a)\nexcept mx.MXNetError as err:\n    sys.stderr.write(str(err))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\n\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n\n\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}