{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Differences between NP on MXNet and NumPy\n",
"\n",
"This topic lists known differences between `mxnet.np` and `numpy`. With this quick reference, NumPy users can more easily adopt the MXNet NumPy-like API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp  # 'o' stands for original\n",
"from mxnet import np, npx\n",
"npx.set_np()  # Configure MXNet to be NumPy-like"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Missing operators\n",
"\n",
"Many, but not all, NumPy operators are supported in MXNet. The missing operators are listed in the [NP on MXNet reference](/api/python/docs/api/ndarray/index.html), where they appear in gray blocks instead of linking to their documentation.\n",
"\n",
"In addition, an operator might not support every argument available in NumPy. For example, MXNet does not support strides. Check each operator's documentation for details.\n",
"\n",
"## Extra functionality\n",
"\n",
"The `mxnet.np` module aims to mimic NumPy, so most of the extra functionality that adapts NumPy for deep learning lives in other modules, such as `npx` for deep learning operators and `autograd` for automatic differentiation. The `np` module API is not yet complete. One notable extension within `np` itself is GPU support: array creation routines accept a `ctx` argument that specifies the device:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use a GPU if one is available; otherwise fall back to the CPU\n",
"gpu = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()\n",
"a = np.array(1, ctx=gpu)\n",
"b = np.random.uniform(ctx=gpu)\n",
"(a, b.context)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can move data across devices with `copyto` or `as_in_context`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a.copyto(npx.cpu()), b.as_in_context(npx.cpu())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Default data types\n",
"\n",
"By default, NumPy uses 64-bit floating-point numbers or 64-bit integers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onp.array([1,2]).dtype, onp.array([1.2,2.3]).dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"MXNet uses 32-bit floating-point numbers as the default data type, since that is the standard data type for deep learning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.array([1,2]).dtype, np.array([1.2,2.3]).dtype"
]
},
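{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you need NumPy's 64-bit precision, one option is to pass the `dtype` argument explicitly when creating an array. This is a minimal sketch that assumes the operators you use afterwards, and the device the array lives on, support `float64`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mxnet import np, npx\n",
"npx.set_np()\n",
"\n",
"# Request 64-bit floats explicitly instead of relying on the float32 default\n",
"# (assumes float64 is supported by the current device and operators)\n",
"c = np.array([1.2, 2.3], dtype='float64')\n",
"c.dtype"
]
},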
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scalars\n",
"\n",
"NumPy has classes for scalars, whose base class is `numpy.generic`. Selecting a single element or applying a reduction operator returns a scalar."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = onp.array([1,2])\n",
"type(a[0]), type(a.sum())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A scalar behaves almost identically to a 0-rank array (there may be subtle differences), but it has a different class. You can check its type with `isinstance`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"b = a[0]\n",
"(b.ndim, b.size, isinstance(b, onp.generic), isinstance(b, onp.integer),\n",
" isinstance(b, onp.int64), isinstance(b, onp.ndarray))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"MXNet has no scalar classes; it returns a 0-rank `ndarray` instead. (Scalar classes may be added in the future.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = np.array([1,2])\n",
"type(a[0]), type(a.sum())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"b = a[0]\n",
"b.ndim, b.size, isinstance(b, np.ndarray)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save\n",
"\n",
"The `npx.save` function saves data in a binary format that is not compatible with the NumPy format. For example, the saved data contains device information."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = np.array(1, ctx=gpu)\n",
"npx.save('a', a)\n",
"npx.load('a')"
]
},
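{
"cell_type": "markdown",
"metadata": {},
"source": [
"To exchange data with NumPy, one approach is to convert the array with `asnumpy()` first and then use NumPy's own `save` and `load`, which write the standard `.npy` format. This is a minimal sketch: the file name `a.npy` is only an example, and device information is not preserved this way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp\n",
"from mxnet import np, npx\n",
"npx.set_np()\n",
"\n",
"# Convert to a NumPy array, then save it in NumPy's standard .npy format\n",
"# ('a.npy' is an example file name)\n",
"a = np.array([1, 2])\n",
"onp.save('a.npy', a.asnumpy())\n",
"# The file can now be read by plain NumPy\n",
"onp.load('a.npy')"
]
},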
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Matplotlib\n",
"\n",
"Sometimes an MXNet `ndarray` cannot be used directly by libraries that accept NumPy input, such as Matplotlib. The best practice is to convert it to a NumPy array with `asnumpy()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(np.array([1,2]).asnumpy());"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 4
}