{
"cells": [
{
"cell_type": "markdown",
"id": "f1502018",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Step 2: Create a neural network\n",
"\n",
"In this step, you learn how to use NP on Apache MXNet to create neural networks\n",
"in Gluon. In addition to the `np` package that you learned about in the previous\n",
"step [Step 1: Manipulate data with NP on MXNet](./1-nparray.ipynb), you also need to\n",
"import the neural network modules from `gluon`. Gluon includes built-in neural\n",
"network layers in the following two modules:\n",
"\n",
"1. `mxnet.gluon.nn`: NN module that maintained by the mxnet team\n",
"2. `mxnet.gluon.contrib.nn`: Experiemental module that is contributed by the\n",
"community\n",
"\n",
"Use the following commands to import the packages required for this step."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b5edb85d",
"metadata": {},
"outputs": [],
"source": [
"from mxnet import np, npx\n",
"from mxnet.gluon import nn\n",
"npx.set_np() # Change MXNet to the numpy-like mode."
]
},
{
"cell_type": "markdown",
"id": "db45d00c",
"metadata": {},
"source": [
"## Create your neural network's first layer\n",
"\n",
"In this section, you will create a simple neural network with Gluon. One of the\n",
"simplest network you can create is a single **Dense** layer or **densely-\n",
"connected** layer. A dense layer consists of nodes in the input that are\n",
"connected to every node in the next layer. Use the following code example to\n",
"start with a dense layer with five output units."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "65f9bd2d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dense(-1 -> 5, linear)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer = nn.Dense(5)\n",
"layer\n",
"# output: Dense(-1 -> 5, linear)"
]
},
{
"cell_type": "markdown",
"id": "1dd2c834",
"metadata": {},
"source": [
"In the example above, the output is `Dense(-1 -> 5, linear)`. The **-1** in the\n",
"output denotes that the size of the input layer is not specified during\n",
"initialization.\n",
"\n",
"You can also call the **Dense** layer with an `in_units` parameter if you know\n",
"the shape of your input unit."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "76eaacee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dense(3 -> 5, linear)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer = nn.Dense(5,in_units=3)\n",
"layer"
]
},
{
"cell_type": "markdown",
"id": "41a603cf",
"metadata": {},
"source": [
"In addition to the `in_units` param, you can also add an activation function to\n",
"the layer using the `activation` param. The Dense layer implements the operation\n",
"\n",
"$$output = \\sigma(W \\cdot X + b)$$\n",
"\n",
"Call the Dense layer with an `activation` parameter to use an activation\n",
"function."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "69a4b1e9",
"metadata": {},
"outputs": [],
"source": [
"layer = nn.Dense(5, in_units=3,activation='relu')"
]
},
{
"cell_type": "markdown",
"id": "1e319ccd",
"metadata": {},
"source": [
"Voila! Congratulations on creating a simple neural network. But for most of your\n",
"use cases, you will need to create a neural network with more than one dense\n",
"layer or with multiple types of other layers. In addition to the `Dense` layer,\n",
"you can find more layers at [mxnet nn layers](../../../api/gluon/nn/index.rst#module-mxnet.gluon.nn)\n",
"\n",
"So now that you have created a neural network, you are probably wondering how to\n",
"pass data into your network?\n",
"\n",
"First, you need to initialize the network weights, if you use the default\n",
"initialization method which draws random values uniformly in the range $[-0.7,\n",
"0.7]$. You can see this in the following example.\n",
"\n",
"**Note**: Initialization is discussed at a little deeper detail in the next\n",
"notebook"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "021e754d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[03:51:58] /work/mxnet/src/storage/storage.cc:202: Using Pooled (Naive) StorageManager for CPU\n"
]
}
],
"source": [
"layer.initialize()"
]
},
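{
"cell_type": "markdown",
"id": "a3f01c2e",
"metadata": {},
"source": [
"If you want a different initialization scheme, you can pass an initializer\n",
"object to `initialize`. The following is a minimal sketch (the name\n",
"`layer_xavier` is just for illustration) that initializes a separate layer with\n",
"the Xavier initializer instead of the default uniform one."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d20f3a",
"metadata": {},
"outputs": [],
"source": [
"from mxnet import init\n",
"\n",
"# Hypothetical example: a separate layer initialized with Xavier\n",
"layer_xavier = nn.Dense(5, in_units=3)\n",
"layer_xavier.initialize(init=init.Xavier())"
]
},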
{
"cell_type": "markdown",
"id": "2e5925d3",
"metadata": {},
"source": [
"Now that you have initialized your network, you can give it data. Passing data\n",
"through a network is also called a forward pass. You can do a forward pass with\n",
"random data, shown in the following example. First, you create a `(10,3)` shape\n",
"random input `x` and feed the data into the layer to compute the output."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8018b3dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.00881556, 0.01138476, 0. , 0. , 0.01936117],\n",
" [0. , 0.0035577 , 0.06854778, 0.03361227, 0. ],\n",
" [0.05338536, 0.01661206, 0. , 0.00864646, 0. ],\n",
" [0. , 0. , 0. , 0.03724934, 0.03653988],\n",
" [0. , 0.00657675, 0.00472842, 0. , 0.04593495],\n",
" [0. , 0.00795121, 0. , 0. , 0.0595542 ],\n",
" [0.02296758, 0.01650022, 0. , 0. , 0.03874438],\n",
" [0.02207369, 0.02308735, 0.04558432, 0.01468477, 0. ],\n",
" [0. , 0.02254252, 0. , 0. , 0.03834942],\n",
" [0.08092358, 0.04224379, 0. , 0. , 0. ]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.random.uniform(-1,1,(10,3))\n",
"layer(x)"
]
},
{
"cell_type": "markdown",
"id": "102f8c57",
"metadata": {},
"source": [
"The layer produces a `(10,5)` shape output from your `(10,3)` input.\n",
"\n",
"**When you don't specify the `in_unit` parameter, the system automatically\n",
"infers it during the first time you feed in data during the first forward step\n",
"after you create and initialize the weights.**"
]
},
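{
"cell_type": "markdown",
"id": "c1e8d4b2",
"metadata": {},
"source": [
"As a quick check, a hypothetical fresh layer created without `in_units`\n",
"reports `-1` for the input dimension until data flows through it. Compare this\n",
"with `layer` below, whose input dimension was inferred as 3 during the forward\n",
"pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9a7e5c3",
"metadata": {},
"outputs": [],
"source": [
"fresh = nn.Dense(5)\n",
"fresh.params  # weight shape is (5, -1) before the first forward pass"
]
},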
{
"cell_type": "code",
"execution_count": 7,
"id": "b54c523e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'weight': Parameter (shape=(5, 3), dtype=float32),\n",
" 'bias': Parameter (shape=(5,), dtype=float32)}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer.params"
]
},
{
"cell_type": "markdown",
"id": "63e986db",
"metadata": {},
"source": [
"The `weights` and `bias` can be accessed using the `.data()` method."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "656bda95",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.01607367, 0.05928481, -0.0319057 ],\n",
" [-0.05814854, 0.01664302, -0.02215988],\n",
" [-0.04094896, 0.03231322, 0.05914024],\n",
" [ 0.05500493, 0.03504761, 0.05073748],\n",
" [ 0.00943237, -0.06525595, -0.04184696]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer.weight.data()"
]
},
{
"cell_type": "markdown",
"id": "011e290e",
"metadata": {},
"source": [
"## Chain layers into a neural network using nn.Sequential\n",
"\n",
"Sequential provides a special way of rapidly building networks when when the\n",
"network architecture follows a common design pattern: the layers look like a\n",
"stack of pancakes. Many networks follow this pattern: a bunch of layers, one\n",
"stacked on top of another, where the output of each layer is fed directly to the\n",
"input to the next layer. To use sequential, simply provide a list of layers\n",
"(pass in the layers by calling `net.add(<Layer goes here!>`). To do this you can\n",
"use your previous example of Dense layers and create a 3-layer multi layer\n",
"perceptron. You can create a sequential block using `nn.Sequential()` method and\n",
"add layers using `add()` method."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a2ab463c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Dense(3 -> 5, Activation(relu))\n",
" (1): Dense(-1 -> 25, Activation(relu))\n",
" (2): Dense(-1 -> 2, linear)\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net = nn.Sequential()\n",
"\n",
"net.add(nn.Dense(5,in_units=3,activation='relu'),\n",
" nn.Dense(25, activation='relu'), nn.Dense(2))\n",
"net"
]
},
{
"cell_type": "markdown",
"id": "4ee4727c",
"metadata": {},
"source": [
"The layers are ordered exactly the way you defined your neural network with\n",
"index starting from 0. You can access the layers by indexing the network using\n",
"`[]`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1b66ee0d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dense(-1 -> 25, Activation(relu))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net[1]"
]
},
{
"cell_type": "markdown",
"id": "c7f5f323",
"metadata": {},
"source": [
"## Create a custom neural network architecture flexibly\n",
"\n",
"`nn.Sequential()` allows you to create your multi-layer neural network with\n",
"existing layers from `gluon.nn`. It also includes a pre-defined `forward()`\n",
"function that sequentially executes added layers. But what if the built-in\n",
"layers are not sufficient for your needs. If you want to create networks like\n",
"ResNet which has complex but repeatable components, how do you create such a\n",
"network?\n",
"\n",
"In gluon, every neural network layer is defined by using a base class\n",
"`nn.Block()`. A Block has one main job - define a forward method that takes some\n",
"input x and generates an output. A Block can just do something simple like apply\n",
"an activation function. It can combine multiple layers together in a single\n",
"block or also combine a bunch of other Blocks together in creative ways to\n",
"create complex networks like Resnet. In this case, you will construct three\n",
"Dense layers. The `forward()` method can then invoke the layers in turn to\n",
"generate its output.\n",
"\n",
"Create a subclass of `nn.Block` and implement two methods by using the following\n",
"code.\n",
"\n",
"- `__init__` create the layers\n",
"- `forward` define the forward function."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "997c5a6b",
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Block):\n",
" def __init__(self):\n",
" super().__init__()\n",
" def forward(self, x):\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b9cc930e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MLP(\n",
" (dense1): Dense(-1 -> 5, Activation(relu))\n",
" (dense2): Dense(-1 -> 25, Activation(relu))\n",
" (dense3): Dense(-1 -> 2, linear)\n",
")"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class MLP(nn.Block):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.dense1 = nn.Dense(5,activation='relu')\n",
" self.dense2 = nn.Dense(25,activation='relu')\n",
" self.dense3 = nn.Dense(2)\n",
"\n",
" def forward(self, x):\n",
" layer1 = self.dense1(x)\n",
" layer2 = self.dense2(layer1)\n",
" layer3 = self.dense3(layer2)\n",
" return layer3\n",
"\n",
"net = MLP()\n",
"net"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f275942e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'weight': Parameter (shape=(5, -1), dtype=float32),\n",
" 'bias': Parameter (shape=(5,), dtype=float32)}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.dense1.params"
]
},
{
"cell_type": "markdown",
"id": "ab1ec782",
"metadata": {},
"source": [
"Each layer includes parameters that are stored in a `Parameter` class. You can\n",
"access them using the `params()` method.\n",
"\n",
"## Creating custom layers using Parameters (Blocks API)\n",
"\n",
"MXNet includes a `Parameter` method to hold your parameters in each layer. You\n",
"can create custom layers using the `Parameter` class to include computation that\n",
"may otherwise be not included in the built-in layers. For example, for a dense\n",
"layer, the weights and biases will be created using the `Parameter` method. But\n",
"if you want to add additional computation to the dense layer, you can create it\n",
"using parameter method.\n",
"\n",
"Instantiate a parameter, e.g weights with a size `(5,0)` using the `shape`\n",
"argument."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2c3088e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Parameter (shape=(5, -1), dtype=<class 'numpy.float32'>),\n",
" Parameter (shape=(5, -1), dtype=<class 'numpy.float32'>))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from mxnet.gluon import Parameter\n",
"\n",
"weight = Parameter(\"custom_parameter_weight\",shape=(5,-1))\n",
"bias = Parameter(\"custom_parameter_bias\",shape=(5,-1))\n",
"\n",
"weight,bias"
]
},
{
"cell_type": "markdown",
"id": "e430e51c",
"metadata": {},
"source": [
"The `Parameter` method includes a `grad_req` argument that specifies how you\n",
"want to capture gradients for this Parameter. Under the hood, that lets gluon\n",
"know that it has to call `.attach_grad()` on the underlying array. By default,\n",
"the gradient is updated everytime the gradient is written to the grad\n",
"`grad_req='write'`.\n",
"\n",
"Now that you know how parameters work, you are ready to create your very own\n",
"fully-connected custom layer.\n",
"\n",
"To create the custom layers using parameters, you can use the same skeleton with\n",
"`nn.Block` base class. You will create a custom dense layer that takes parameter\n",
"x and returns computed `w*x + b` without any activation function"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9b0c2ca2",
"metadata": {},
"outputs": [],
"source": [
"class custom_layer(nn.Block):\n",
" def __init__(self, out_units, in_units=0):\n",
" super().__init__()\n",
" self.weight = Parameter(\"weight\", shape=(in_units,out_units), allow_deferred_init=True)\n",
" self.bias = Parameter(\"bias\", shape=(out_units,), allow_deferred_init=True)\n",
" def forward(self, x):\n",
" return np.dot(x, self.weight.data()) + self.bias.data()"
]
},
{
"cell_type": "markdown",
"id": "17d3916e",
"metadata": {},
"source": [
"Parameter can be instantiated before the corresponding data is instantiated. For\n",
"example, when you instantiate a Block but the shapes of each parameter still\n",
"need to be inferred, the Parameter will wait for the shape to be inferred before\n",
"allocating memory."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "76991cb3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.05604633, -0.06238654, 0.02687173],\n",
" [-0.02687152, -0.04365591, -0.00518382],\n",
" [-0.02849396, -0.09980228, -0.00695815],\n",
" [-0.04527343, -0.00275569, -0.01376584]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dense = custom_layer(3,in_units=5)\n",
"dense.initialize()\n",
"dense(np.random.uniform(size=(4, 5)))"
]
},
{
"cell_type": "markdown",
"id": "fabc9407",
"metadata": {},
"source": [
"Similarly, you can use the following code to implement a famous network called\n",
"[LeNet](http://yann.lecun.com/exdb/lenet/) through `nn.Block` using the built-in\n",
"`Dense` layer and using `custom_layer` as the last layer"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "58855487",
"metadata": {},
"outputs": [],
"source": [
"class LeNet(nn.Block):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2D(channels=6, kernel_size=3, activation='relu')\n",
" self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)\n",
" self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')\n",
" self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)\n",
" self.dense1 = nn.Dense(120, activation=\"relu\")\n",
" self.dense2 = nn.Dense(84, activation=\"relu\")\n",
" self.dense3 = nn.Dense(10)\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.pool1(x)\n",
" x = self.conv2(x)\n",
" x = self.pool2(x)\n",
" x = self.dense1(x)\n",
" x = self.dense2(x)\n",
" x = self.dense3(x)\n",
" return x\n",
"\n",
"lenet = LeNet()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d69ba80c",
"metadata": {},
"outputs": [],
"source": [
"class LeNet_custom(nn.Block):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2D(channels=6, kernel_size=3, activation='relu')\n",
" self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)\n",
" self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')\n",
" self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)\n",
" self.dense1 = nn.Dense(120, activation=\"relu\")\n",
" self.dense2 = nn.Dense(84, activation=\"relu\")\n",
" self.dense3 = custom_layer(10,84)\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.pool1(x)\n",
" x = self.conv2(x)\n",
" x = self.pool2(x)\n",
" x = self.dense1(x)\n",
" x = self.dense2(x)\n",
" x = self.dense3(x)\n",
" return x\n",
"\n",
"lenet_custom = LeNet_custom()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "32c9aff7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Lenet:\n",
"[[-0.00081668 -0.00340701 0.00199039 -0.00121501 -0.00063157 0.00081476\n",
" -0.0011025 -0.00216652 -0.00015363 -0.00110007]]\n",
"Custom Lenet:\n",
"[[-0.02656163 0.04184292 -0.00254037 -0.06494093 -0.00952166 -0.01921579\n",
" 0.05423243 -0.02774546 0.06823301 0.00313227]]\n"
]
}
],
"source": [
"image_data = np.random.uniform(-1,1, (1,1,28,28))\n",
"\n",
"lenet.initialize()\n",
"lenet_custom.initialize()\n",
"\n",
"print(\"Lenet:\")\n",
"print(lenet(image_data))\n",
"\n",
"print(\"Custom Lenet:\")\n",
"print(lenet_custom(image_data))"
]
},
{
"cell_type": "markdown",
"id": "aa9cc4bd",
"metadata": {},
"source": [
"You can use `.data` method to access the weights and bias of a particular layer.\n",
"For example, the following accesses the first layer's weight and sixth layer's bias."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "1c293d07",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((6, 1, 3, 3), (120,))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lenet.conv1.weight.data().shape, lenet.dense1.bias.data().shape"
]
},
{
"cell_type": "markdown",
"id": "0c94cea7",
"metadata": {},
"source": [
"## Using predefined (pretrained) architectures\n",
"\n",
"Till now, you have seen how to create your own neural network architectures. But\n",
"what if you want to replicate or baseline your dataset using some of the common\n",
"models in computer visions or natural language processing (NLP). Gluon includes\n",
"common architectures that you can directly use. The Gluon Model Zoo provides a\n",
"collection of off-the-shelf models e.g. RESNET, BERT etc. These architectures\n",
"are found at:\n",
"\n",
"- [Gluon CV model zoo](https://cv.gluon.ai/model_zoo/index.html)\n",
"\n",
"- [Gluon NLP model zoo](https://nlp.gluon.ai/model_zoo/index.html)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "5b7afa70",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading /home/jenkins_slave/.mxnet/models/resnet50_v2-ecdde353.zip00383814-e655-4621-a110-5ffefe3eb69c from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet50_v2-ecdde353.zip...\n"
]
},
{
"data": {
"text/plain": [
"(1, 1000)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from mxnet.gluon import model_zoo\n",
"\n",
"net = model_zoo.vision.resnet50_v2(pretrained=True)\n",
"net.hybridize()\n",
"\n",
"dummy_input = np.ones(shape=(1,3,224,224))\n",
"output = net(dummy_input)\n",
"output.shape"
]
},
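{
"cell_type": "markdown",
"id": "e2b91a77",
"metadata": {},
"source": [
"The output is a `(1, 1000)` array of class scores. As a quick illustration, you\n",
"can turn these scores into probabilities with a softmax and take the index of\n",
"the largest one as the predicted ImageNet class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4c3b8d1",
"metadata": {},
"outputs": [],
"source": [
"probabilities = npx.softmax(output)\n",
"probabilities.argmax(axis=1)  # index of the most likely class"
]
},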
{
"cell_type": "markdown",
"id": "af120df5",
"metadata": {},
"source": [
"## Deciding the paradigm for your network\n",
"\n",
"In MXNet, Gluon API (Imperative programming paradigm) provides a user friendly\n",
"way for quick prototyping, easy debugging and natural control flow for people\n",
"familiar with python programming.\n",
"\n",
"However, at the backend, MXNET can also convert the network using Symbolic or\n",
"Declarative programming into static graphs with low level optimizations on\n",
"operators. However, static graphs are less flexible because any logic must be\n",
"encoded into the graph as special operators like scan, while_loop and cond. It’s\n",
"also hard to debug.\n",
"\n",
"So how can you make use of symbolic programming while getting the flexibility of\n",
"imperative programming to quickly prototype and debug?\n",
"\n",
"Enter **HybridBlock**\n",
"\n",
"HybridBlocks can run in a fully imperatively way where you define their\n",
"computation with real functions acting on real inputs. But they’re also capable\n",
"of running symbolically, acting on placeholders. Gluon hides most of this under\n",
"the hood so you will only need to know how it works when you want to write your\n",
"own layers."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "abc3493e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"HybridSequential(\n",
" (0): Dense(3 -> 5, Activation(relu))\n",
" (1): Dense(-1 -> 25, Activation(relu))\n",
" (2): Dense(-1 -> 2, linear)\n",
")"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net_hybrid_seq = nn.HybridSequential()\n",
"\n",
"net_hybrid_seq.add(nn.Dense(5,in_units=3,activation='relu'),\n",
" nn.Dense(25, activation='relu'), nn.Dense(2) )\n",
"net_hybrid_seq"
]
},
{
"cell_type": "markdown",
"id": "581c517f",
"metadata": {},
"source": [
"To compile and optimize `HybridSequential`, you can call its `hybridize` method."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "af36f9fc",
"metadata": {},
"outputs": [],
"source": [
"net_hybrid_seq.hybridize()"
]
},
{
"cell_type": "markdown",
"id": "64823d61",
"metadata": {},
"source": [
"## Creating custom layers using Parameters (HybridBlocks API)\n",
"\n",
"When you instantiated your custom layer, you specified the input dimension\n",
"`in_units` that initializes the weights with the shape specified by `in_units`\n",
"and `out_units`. If you leave the shape of `in_unit` as unknown, you defer the\n",
"shape to the first forward pass. For the custom layer, you define the\n",
"`infer_shape()` method and let the shape be inferred at runtime."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "eb9756b1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/work/mxnet/python/mxnet/util.py:755: UserWarning: Parameter 'weight' is already initialized, ignoring. Set force_reinit=True to re-initialize.\n",
" return func(*args, **kwargs)\n",
"/work/mxnet/python/mxnet/util.py:755: UserWarning: Parameter 'bias' is already initialized, ignoring. Set force_reinit=True to re-initialize.\n",
" return func(*args, **kwargs)\n"
]
},
{
"data": {
"text/plain": [
"array([[-0.07053316, -0.07457963, 0.01166525],\n",
" [-0.04170407, -0.07482161, 0.00179428],\n",
" [-0.07503258, 0.00660181, -0.01401043],\n",
" [-0.02333996, -0.06775613, 0.01459978]])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class CustomLayer(nn.HybridBlock):\n",
" def __init__(self, out_units, in_units=-1):\n",
" super().__init__()\n",
" self.weight = Parameter(\"weight\", shape=(in_units, out_units), allow_deferred_init=True)\n",
" self.bias = Parameter(\"bias\", shape=(out_units,), allow_deferred_init=True)\n",
"\n",
" def forward(self, x):\n",
" print(self.weight.shape, self.bias.shape)\n",
" return np.dot(x, self.weight.data()) + self.bias.data()\n",
"\n",
" def infer_shape(self, x):\n",
" print(self.weight.shape,x.shape)\n",
" self.weight.shape = (x.shape[-1],self.weight.shape[1])\n",
" dense = CustomLayer(3)\n",
"\n",
"dense.initialize()\n",
"dense(np.random.uniform(size=(4, 5)))"
]
},
{
"cell_type": "markdown",
"id": "844e4d15",
"metadata": {},
"source": [
"### Performance\n",
"\n",
"To get a sense of the speedup from hybridizing, you can compare the performance\n",
"before and after hybridizing by measuring the time it takes to make 1000 forward\n",
"passes through the network."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "8e0d8df9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before hybridizing: 0.6034 sec\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"After hybridizing: 0.2799 sec\n"
]
}
],
"source": [
"from time import time\n",
"\n",
"def benchmark(net, x):\n",
" y = net(x)\n",
" start = time()\n",
" for i in range(1,1000):\n",
" y = net(x)\n",
" return time() - start\n",
"\n",
"x_bench = np.random.normal(size=(1,512))\n",
"\n",
"net_hybrid_seq = nn.HybridSequential()\n",
"\n",
"net_hybrid_seq.add(nn.Dense(256,activation='relu'),\n",
" nn.Dense(128, activation='relu'),\n",
" nn.Dense(2))\n",
"net_hybrid_seq.initialize()\n",
"\n",
"print('Before hybridizing: %.4f sec'%(benchmark(net_hybrid_seq, x_bench)))\n",
"net_hybrid_seq.hybridize()\n",
"print('After hybridizing: %.4f sec'%(benchmark(net_hybrid_seq, x_bench)))"
]
},
{
"cell_type": "markdown",
"id": "6f1f1349",
"metadata": {},
"source": [
"Peeling back another layer, you also have a `HybridBlock` which is the hybrid\n",
"version of the `Block` API.\n",
"\n",
"Similar to the `Blocks` API, you define a `forward` function for `HybridBlock`\n",
"that takes an input `x`. MXNet takes care of hybridizing the model at the\n",
"backend so you don't have to make changes to your code to convert it to a\n",
"symbolic paradigm."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "07c6cce6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before hybridizing: 0.5799 sec\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"After hybridizing: 0.2603 sec\n"
]
}
],
"source": [
"from mxnet.gluon import HybridBlock\n",
"\n",
"class MLP_Hybrid(HybridBlock):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.dense1 = nn.Dense(256,activation='relu')\n",
" self.dense2 = nn.Dense(128,activation='relu')\n",
" self.dense3 = nn.Dense(2)\n",
" def forward(self, x):\n",
" layer1 = self.dense1(x)\n",
" layer2 = self.dense2(layer1)\n",
" layer3 = self.dense3(layer2)\n",
" return layer3\n",
"\n",
"net_hybrid = MLP_Hybrid()\n",
"net_hybrid.initialize()\n",
"\n",
"print('Before hybridizing: %.4f sec'%(benchmark(net_hybrid, x_bench)))\n",
"net_hybrid.hybridize()\n",
"print('After hybridizing: %.4f sec'%(benchmark(net_hybrid, x_bench)))"
]
},
{
"cell_type": "markdown",
"id": "12f5f4a9",
"metadata": {},
"source": [
"Given a HybridBlock whose forward computation consists of going through other\n",
"HybridBlocks, you can compile that section of the network by calling the\n",
"HybridBlocks `.hybridize()` method.\n",
"\n",
"All of MXNet’s predefined layers are HybridBlocks. This means that any network\n",
"consisting entirely of predefined MXNet layers can be compiled and run at much\n",
"faster speeds by calling `.hybridize()`.\n",
"\n",
"## Saving and Loading your models\n",
"\n",
"The Blocks API also includes saving your models during and after training so\n",
"that you can host the model for inference or avoid training the model again from\n",
"scratch. Another reason would be to train your model using one language (like\n",
"Python that has a lot of tools for training) and run inference using a different\n",
"language.\n",
"\n",
"There are two ways to save your model in MXNet.\n",
"1. Save/load the model weights/parameters only\n",
"2. Save/load the model weights/parameters and the architectures\n",
"\n",
"\n",
"### 1. Save/load the model weights/parameters only\n",
"\n",
"You can use `save_parameters` and `load_parameters` method to save and load the\n",
"model weights. Take your simplest model `layer` and save your parameters first.\n",
"The model parameters are the params that you save **after** you train your\n",
"model."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "abcaf4e1",
"metadata": {},
"outputs": [],
"source": [
"file_name = 'layer.params'\n",
"layer.save_parameters(file_name)"
]
},
{
"cell_type": "markdown",
"id": "70a86a96",
"metadata": {},
"source": [
"And now load this model again. To load the parameters into a model, you will\n",
"first have to build the model. To do this, you will need to create a simple\n",
"function to build it."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "2bd2f6b1",
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
" layer = nn.Dense(5, in_units=3,activation='relu')\n",
" return layer\n",
"\n",
"layer_new = build_model()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "47024147",
"metadata": {},
"outputs": [],
"source": [
"layer_new.load_parameters('layer.params')"
]
},
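{
"cell_type": "markdown",
"id": "0a9d3e5f",
"metadata": {},
"source": [
"As an optional sanity check, you can confirm that the loaded weights match the\n",
"originals; the difference should be exactly zero:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b8c4f6a",
"metadata": {},
"outputs": [],
"source": [
"# should print 0 if the parameters were loaded correctly\n",
"(layer_new.weight.data() - layer.weight.data()).sum()"
]
},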
{
"cell_type": "markdown",
"id": "a1647f15",
"metadata": {},
"source": [
"**Note**: The `save_parameters` and `load_parameters` method is used for models\n",
"that use a `Block` method instead of `HybridBlock` method to build the model.\n",
"These models may have complex architectures where the model architectures may\n",
"change during execution. E.g. if you have a model that uses an if-else\n",
"conditional statement to choose between two different architectures.\n",
"\n",
"### 2. Save/load the model weights/parameters and the architectures\n",
"\n",
"For models that use the **HybridBlock**, the model architecture stays static and\n",
"do no change during execution. Therefore both model parameters **AND**\n",
"architecture can be saved and loaded using `export`, `imports` methods.\n",
"\n",
"Now look at your `MLP_Hybrid` model and export the model using the `export`\n",
"function. The export function will export the model architecture into a `.json`\n",
"file and model parameters into a `.params` file."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "0311a523",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('MLP_hybrid-symbol.json', 'MLP_hybrid-0000.params')"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net_hybrid.export('MLP_hybrid')"
]
},
{
"cell_type": "markdown",
"id": "380e7cdc",
"metadata": {},
"source": [
"Similarly, to load this model back, you can use `gluon.nn.SymbolBlock`. To\n",
"demonstrate that, load the network serialized above."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "f2322422",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" net_loaded = nn.SymbolBlock.imports(\"MLP_hybrid-symbol.json\",\n",
" ['data'], \"MLP_hybrid-0000.params\",\n",
" device=None)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "c7bc1d0f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.13653663, -0.07247495]])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net_loaded(x_bench)"
]
},
{
"cell_type": "markdown",
"id": "e829ef5a",
"metadata": {},
"source": [
"## Visualizing your models\n",
"\n",
"In MXNet, the `Block.Summary()` method allows you to view the block’s shape\n",
"arguments and view the block’s parameters. When you combine multiple blocks into\n",
"a model, the `summary()` applied on the model allows you to view each block’s\n",
"summary, the total parameters, and the order of the blocks within the model. To\n",
"do this the `Block.summary()` method requires one forward pass of the data,\n",
"through your network, in order to create the graph necessary for capturing the\n",
"corresponding shapes and parameters. Additionally, this method should be called\n",
"before the hybridize method, since the hybridize method converts the graph into\n",
"a symbolic one, potentially changing the operations for optimal computation.\n",
"\n",
"Look at the following examples\n",
"\n",
"- layer: our single layer network\n",
"- Lenet: a non-hybridized LeNet network\n",
"- net_Hybrid: our MLP Hybrid network"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "16668c82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================================\n",
" Input (10, 3) 0\n",
" Activation-1 (10, 5) 0\n",
" Dense-2 (10, 5) 20\n",
"================================================================================\n",
"Parameters in forward computation graph, duplicate included\n",
" Total params: 20\n",
" Trainable params: 20\n",
" Non-trainable params: 0\n",
"Shared params in forward computation graph: 0\n",
"Unique parameters in model: 20\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"layer.summary(x)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "7ece38c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================================\n",
" Input (1, 1, 28, 28) 0\n",
" Activation-1 (1, 6, 26, 26) 0\n",
" Conv2D-2 (1, 6, 26, 26) 60\n",
" MaxPool2D-3 (1, 6, 13, 13) 0\n",
" Activation-4 (1, 16, 11, 11) 0\n",
" Conv2D-5 (1, 16, 11, 11) 880\n",
" MaxPool2D-6 (1, 16, 5, 5) 0\n",
" Activation-7 (1, 120) 0\n",
" Dense-8 (1, 120) 48120\n",
" Activation-9 (1, 84) 0\n",
" Dense-10 (1, 84) 10164\n",
" Dense-11 (1, 10) 850\n",
" LeNet-12 (1, 10) 0\n",
"================================================================================\n",
"Parameters in forward computation graph, duplicate included\n",
" Total params: 60074\n",
" Trainable params: 60074\n",
" Non-trainable params: 0\n",
"Shared params in forward computation graph: 0\n",
"Unique parameters in model: 60074\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"lenet.summary(image_data)"
]
},
{
"cell_type": "markdown",
"id": "da1e1a8f",
"metadata": {},
"source": [
"You are able to print the summaries of the two networks `layer` and `lenet`\n",
"easily since you didn't hybridize the two networks. However, the last network\n",
"`net_Hybrid` was hybridized above and throws an `AssertionError` if you try\n",
"`net_Hybrid.summary(x_bench)`. To print the summary for `net_Hybrid`, call\n",
"another instance of the same network and instantiate it for our summary and then\n",
"hybridize it"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "6f429bc7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================================\n",
" Input (1, 512) 0\n",
" Activation-1 (1, 256) 0\n",
" Dense-2 (1, 256) 131328\n",
" Activation-3 (1, 128) 0\n",
" Dense-4 (1, 128) 32896\n",
" Dense-5 (1, 2) 258\n",
" MLP_Hybrid-6 (1, 2) 0\n",
"================================================================================\n",
"Parameters in forward computation graph, duplicate included\n",
" Total params: 164482\n",
" Trainable params: 164482\n",
" Non-trainable params: 0\n",
"Shared params in forward computation graph: 0\n",
"Unique parameters in model: 164482\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"net_hybrid_summary = MLP_Hybrid()\n",
"\n",
"net_hybrid_summary.initialize()\n",
"\n",
"net_hybrid_summary.summary(x_bench)\n",
"\n",
"net_hybrid_summary.hybridize()"
]
},
{
"cell_type": "markdown",
"id": "f76b4ea5",
"metadata": {},
"source": [
"## Next steps:\n",
"\n",
"Now that you have created a neural network, learn how to automatically compute\n",
"the gradients in [Step 3: Automatic differentiation with autograd](./3-autograd.ipynb)."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}