{"nbformat": 4, "cells": [{"source": "# NDArray - Imperative tensor operations on CPU/GPU\n\nIn _MXNet_, `NDArray` is the core data structure for all mathematical\ncomputations. An `NDArray` represents a multidimensional, fixed-size homogeneous\narray. If you're familiar with the scientific computing Python package\n[NumPy](http://www.numpy.org/), you might notice that `mxnet.ndarray` is similar\nto `numpy.ndarray`. Like the corresponding NumPy data structure, MXNet's\n`NDArray` enables imperative computation.\n\nSo you might wonder, why not just use NumPy? MXNet offers two compelling\nadvantages. First, MXNet's `NDArray` supports fast execution on a wide range of\nhardware configurations, including CPU, GPU, and multi-GPU machines. _MXNet_\nalso scales to distributed systems in the cloud. Second, MXNet's `NDArray`\nexecutes code lazily, allowing it to automatically parallelize multiple\noperations across the available hardware.\n\nAn `NDArray` is a multidimensional array of numbers with the same type. We\ncould represent the coordinates of a point in 3D space, e.g. `[2, 1, 6]` as a 1D\narray with shape `(3,)`. Similarly, we could represent a 2D array. Below, we\npresent an array with length 2 along the first axis and length 3 along the\nsecond axis.\n```\n[[0, 1, 2]\n [3, 4, 5]]\n```\nNote that here the use of \"dimension\" is overloaded. When we say a 2D array, we\nmean an array with 2 axes, not an array with two components.\n\nEach NDArray supports some important attributes that you'll often want to query:\n\n- **ndarray.shape**: The dimensions of the array. It is a tuple of integers\n indicating the length of the array along each axis. 
For a matrix with `n` rows\n and `m` columns, its `shape` will be `(n, m)`.\n- **ndarray.dtype**: A `numpy` _type_ object describing the type of its\n elements.\n- **ndarray.size**: The total number of components in the array - equal to the\n product of the components of its `shape`\n- **ndarray.context**: The device on which this array is stored, e.g. `cpu()` or\n `gpu(1)`.\n\n## Prerequisites\n\nTo complete this tutorial, we need:\n\n- MXNet. See the instructions for your operating system in [Setup and Installation](http://mxnet.io/install/index.html)\n- [Jupyter](http://jupyter.org/)\n ```\n pip install jupyter\n ```\n- GPUs - A section of this tutorial uses GPUs. If you don't have GPUs on your\nmachine, simply set the variable gpu_device (set in the GPUs section of this \ntutorial) to mx.cpu().\n\n## Array Creation\n\nThere are a few different ways to create an `NDArray`.\n\n* We can create an NDArray from a regular Python list or tuple by using the `array` function:", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\n# create a 1-dimensional array with a python list\na = mx.nd.array([1,2,3])\n# create a 2-dimensional array with a nested python list\nb = mx.nd.array([[1,2,3], [2,3,4]])\n{'a.shape':a.shape, 'b.shape':b.shape}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "* We can also create an MXNet NDArray from a `numpy.ndarray` object:", "cell_type": "markdown", "metadata": {}}, {"source": "import numpy as np\nimport math\nc = np.arange(15).reshape(3,5)\n# create a 2-dimensional array from a numpy.ndarray object\na = mx.nd.array(c)\n{'a.shape':a.shape}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can specify the element type with the option `dtype`, which accepts a numpy\ntype. 
By default, `float32` is used:", "cell_type": "markdown", "metadata": {}}, {"source": "# float32 is used by default\na = mx.nd.array([1,2,3])\n# create an int32 array\nb = mx.nd.array([1,2,3], dtype=np.int32)\n# create a 16-bit float array\nc = mx.nd.array([1.2, 2.3], dtype=np.float16)\n(a.dtype, b.dtype, c.dtype)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "If we know the shape of the desired NDArray, but not the element values, MXNet\noffers several functions to create arrays with placeholder content:", "cell_type": "markdown", "metadata": {}}, {"source": "# create a 2-dimensional array full of zeros with shape (2,3)\na = mx.nd.zeros((2,3))\n# create an array of the same shape, full of ones\nb = mx.nd.ones((2,3))\n# create an array of the same shape with all elements set to 7\nc = mx.nd.full((2,3), 7)\n# create an array of the same shape whose initial content is arbitrary and\n# depends on the state of the memory\nd = mx.nd.empty((2,3))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Printing Arrays\n\nWhen inspecting the contents of an `NDArray`, it's often convenient to first\nextract its contents as a `numpy.ndarray` using the `asnumpy` method. NumPy\nuses the following layout when printing:\n\n- The last axis is printed from left to right,\n- The second-to-last is printed from top to bottom,\n- The rest are also printed from top to bottom, with each slice separated from\n the next by an empty line.", "cell_type": "markdown", "metadata": {}}, {"source": "b = mx.nd.arange(18).reshape((3,2,3))\nb.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Basic Operations\n\nWhen applied to NDArrays, the standard arithmetic operators apply *elementwise*\ncalculations. 
The returned value is a new array that contains the\nresult.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,3))\nb = mx.nd.ones((2,3))\n# elementwise plus\nc = a + b\n# elementwise negation\nd = - c\n# elementwise pow and sin, and then transpose\ne = mx.nd.sin(c**2).T\n# elementwise max\nf = mx.nd.maximum(a, c)\nf.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "As in `NumPy`, `*` represents element-wise multiplication. For matrix-matrix\nmultiplication, use `dot`.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.arange(4).reshape((2,2))\nb = a * a\nc = mx.nd.dot(a,a)\nprint(\"b: %s, \\n c: %s\" % (b.asnumpy(), c.asnumpy()))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The assignment operators such as `+=` and `*=` modify arrays in place, and thus\ndon't allocate new memory to create a new array.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,2))\nb = mx.nd.ones(a.shape)\nb += a\nb.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Indexing and Slicing\n\nThe slice operator `[]` applies to axis 0.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.array(np.arange(6).reshape(3,2))\na[1:2] = 1\na[:].asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can also slice a particular axis with the method `slice_axis`:", "cell_type": "markdown", "metadata": {}}, {"source": "d = mx.nd.slice_axis(a, axis=1, begin=1, end=2)\nd.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Shape Manipulation\n\nUsing `reshape`, we can manipulate any array's shape as long as the size remains\nunchanged.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.array(np.arange(24))\nb = a.reshape((2,3,4))\nb.asnumpy()", "cell_type": 
"code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The `concat` method joins multiple arrays along the axis given by `dim`. Their\nshapes must be the same along the other axes. Here we stack along the first axis:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,3))\nb = mx.nd.ones((2,3))*2\nc = mx.nd.concat(a, b, dim=0)\nc.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Reduce\n\nSome functions, like `sum` and `mean`, reduce arrays to scalars.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,3))\nb = mx.nd.sum(a)\nb.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can also reduce an array along a particular axis:", "cell_type": "markdown", "metadata": {}}, {"source": "c = mx.nd.sum_axis(a, axis=1)\nc.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Broadcast\n\nWe can also broadcast an array. Broadcasting operations duplicate an array's\nvalues along an axis with length 1. The following code broadcasts along axis 1:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.array(np.arange(6).reshape(6,1))\nb = a.broadcast_to((6,4))\nb.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "It's possible to simultaneously broadcast along multiple axes. In the following example, we broadcast along axes 1 and 2:", "cell_type": "markdown", "metadata": {}}, {"source": "c = a.reshape((2,1,1,3))\nd = c.broadcast_to((2,2,2,3))\nd.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Broadcasting can be applied automatically when executing some operations,\ne.g. 
`*` and `+` on arrays of different shapes.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((3,2))\nb = mx.nd.ones((1,2))\nc = a + b\nc.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Copies\n\nWhen assigning an NDArray to another Python variable, we copy a reference to the\n*same* NDArray. However, we often need to make a copy of the data, so that we\ncan manipulate the new array without overwriting the original values.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,2))\nb = a\nb is a # will be True", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The `copy` method makes a deep copy of the array and its data:", "cell_type": "markdown", "metadata": {}}, {"source": "b = a.copy()\nb is a # will be False", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The above code allocates a new NDArray and then assigns to *b*. When we do not\nwant to allocate additional memory, we can use the `copyto` method or the slice\noperator `[]` instead.", "cell_type": "markdown", "metadata": {}}, {"source": "b = mx.nd.ones(a.shape)\nc = b\nc[:] = a\nd = b\na.copyto(d)\n(c is b, d is b) # Both will be True", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Advanced Topics\n\nMXNet's NDArray offers some advanced features that differentiate it from the\nofferings you'll find in most other libraries.\n\n### GPU Support\n\nBy default, NDArray operators are executed on CPU. But with MXNet, it's easy to\nswitch to another computation resource, such as GPU, when available. Each\nNDArray's device information is stored in `ndarray.context`. When MXNet is\ncompiled with flag `USE_CUDA=1` and the machine has at least one NVIDIA GPU, we\ncan cause all computations to run on GPU 0 by using context `mx.gpu(0)`, or\nsimply `mx.gpu()`. 
When we have access to two or more GPUs, the 2nd GPU is\nrepresented by `mx.gpu(1)`, etc.\n\n**Note** To execute the following section on a CPU, set `gpu_device` to `mx.cpu()`.", "cell_type": "markdown", "metadata": {}}, {"source": "gpu_device=mx.gpu() # Change this to mx.cpu() in absence of GPUs.\n\n\ndef f():\n    a = mx.nd.ones((100,100))\n    b = mx.nd.ones((100,100))\n    c = a + b\n    print(c)\n# by default, mx.cpu() is used\nf()\n# change the default context to gpu_device\nwith mx.Context(gpu_device):\n    f()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can also explicitly specify the context when creating an array:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((100, 100), gpu_device)\na", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Currently, MXNet requires two arrays to sit on the same device for\ncomputation. There are several methods for copying data between devices.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((100,100), mx.cpu())\nb = mx.nd.ones((100,100), gpu_device)\nc = mx.nd.ones((100,100), gpu_device)\na.copyto(c) # copy from CPU to GPU\nd = b + c\ne = b.as_in_context(c.context) + c # same as above\n{'d':d, 'e':e}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "### Serialize From/To (Distributed) Filesystems\n\nMXNet offers two simple ways to save (load) data to (from) disk. The first way\nis to use `pickle`, as you might with any other Python object. 
`NDArray` is\npickle-compatible.", "cell_type": "markdown", "metadata": {}}, {"source": "import pickle as pkl\na = mx.nd.ones((2, 3))\n# pack and then dump into disk\ndata = pkl.dumps(a)\npkl.dump(data, open('tmp.pickle', 'wb'))\n# load from disk and then unpack\ndata = pkl.load(open('tmp.pickle', 'rb'))\nb = pkl.loads(data)\nb.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The second way is to directly dump to disk in binary format by using the `save`\nand `load` methods. We can save/load a single NDArray, or a list of NDArrays:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,3))\nb = mx.nd.ones((5,6))\nmx.nd.save(\"temp.ndarray\", [a,b])\nc = mx.nd.load(\"temp.ndarray\")\nc", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "It's also possible to save or load a dict of NDArrays in this fashion:", "cell_type": "markdown", "metadata": {}}, {"source": "d = {'a':a, 'b':b}\nmx.nd.save(\"temp.ndarray\", d)\nc = mx.nd.load(\"temp.ndarray\")\nc", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The `load` and `save` methods are preferable to pickle in two respects\n\n1. When using these methods, you can save data from within the Python interface\n and then use it later from another language's binding. For example, if we save\n the data in Python:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2, 3))\nmx.nd.save(\"temp.ndarray\", [a,])", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "we can later load it from R:\n```\na <- mx.nd.load(\"temp.ndarray\")\nas.array(a[[1]])\n## [,1] [,2] [,3]\n## [1,] 1 1 1\n## [2,] 1 1 1\n```\n\n2. 
When a distributed filesystem such as Amazon S3 or Hadoop HDFS is set up, we\n can directly save to and load from it.\n\n```\nmx.nd.save('s3://mybucket/mydata.ndarray', [a,]) # if compiled with USE_S3=1\nmx.nd.save('hdfs:///users/myname/mydata.bin', [a,]) # if compiled with USE_HDFS=1\n```\n\n### Lazy Evaluation and Automatic Parallelization\n\nMXNet uses lazy evaluation to achieve superior performance. When we run `a=b+1`\nin Python, the Python thread just pushes this operation into the backend engine\nand then returns. There are two benefits to this approach:\n\n1. The main Python thread can continue to execute other computations once the\n previous one is pushed. This is useful for frontend languages with heavy\n overheads.\n2. It is easier for the backend engine to explore further optimization, such as\n automatic parallelization.\n\nThe backend engine can resolve data dependencies and schedule the computations\ncorrectly. This is transparent to frontend users. We can explicitly call the\nmethod `wait_to_read` on the result array to wait until the computation\nfinishes. Operations that copy data from an array to other packages, such as\n`asnumpy`, will implicitly call `wait_to_read`.", "cell_type": "markdown", "metadata": {}}, {"source": "import time\ndef do(x, n):\n    \"\"\"push computation into the backend engine\"\"\"\n    return [mx.nd.dot(x,x) for i in range(n)]\ndef wait(x):\n    \"\"\"wait until all results are available\"\"\"\n    for y in x:\n        y.wait_to_read()\n\ntic = time.time()\na = mx.nd.ones((1000,1000))\nb = do(a, 50)\nprint('time to push all computations into the backend engine:\\n %f sec' % (time.time() - tic))\nwait(b)\nprint('time until all computations are finished:\\n %f sec' % (time.time() - tic))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Besides analyzing data read and write dependencies, the backend engine is able\nto schedule computations with no dependency in parallel. 
For example, in the\nfollowing code:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,3))\nb = a + 1\nc = a + 2\nd = b * c", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "the second and third lines can be executed in parallel. The following example\nfirst runs a workload on the CPU and then another on the GPU, one after the other:", "cell_type": "markdown", "metadata": {}}, {"source": "n = 10\na = mx.nd.ones((1000,1000))\nb = mx.nd.ones((6000,6000), gpu_device)\ntic = time.time()\nc = do(a, n)\nwait(c)\nprint('Time to finish the CPU workload: %f sec' % (time.time() - tic))\nd = do(b, n)\nwait(d)\nprint('Time to finish both CPU/GPU workloads: %f sec' % (time.time() - tic))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Now we issue all workloads at the same time. The backend engine will try to\nparallelize the CPU and GPU computations.", "cell_type": "markdown", "metadata": {}}, {"source": "tic = time.time()\nc = do(a, n)\nd = do(b, n)\nwait(c)\nwait(d)\nprint('Both finished in: %f sec' % (time.time() - tic))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}