{
"cells": [
{
"cell_type": "markdown",
"id": "2f5a7709",
"metadata": {},
"source": [
"<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n",
"<!--- or more contributor license agreements. See the NOTICE file -->\n",
"<!--- distributed with this work for additional information -->\n",
"<!--- regarding copyright ownership. The ASF licenses this file -->\n",
"<!--- to you under the Apache License, Version 2.0 (the -->\n",
"<!--- \"License\"); you may not use this file except in compliance -->\n",
"<!--- with the License. You may obtain a copy of the License at -->\n",
"\n",
"<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n",
"\n",
"<!--- Unless required by applicable law or agreed to in writing, -->\n",
"<!--- software distributed under the License is distributed on an -->\n",
"<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n",
"<!--- KIND, either express or implied. See the License for the -->\n",
"<!--- specific language governing permissions and limitations -->\n",
"<!--- under the License. -->\n",
"\n",
"# Custom Loss Blocks\n",
"\n",
"All neural networks need a loss function for training. A loss function is a quantitive measure of how bad the predictions of the network are when compared to ground truth labels. Given this score, a network can improve by iteratively updating its weights to minimise this loss. Some tasks use a combination of multiple loss functions, but often you'll just use one. MXNet Gluon provides a number of the most commonly used loss functions, and you'll choose certain functions depending on your network and task. Some common task and loss function pairs include:\n",
"\n",
"- Regression: [L1Loss](/api/python/docs/api/gluon/loss/index.html#mxnet.gluon.loss.L1Loss), [L2Loss](/api/python/docs/api/gluon/loss/index.html#mxnet.gluon.loss.L2Loss)\n",
"- Classification: [SigmoidBinaryCrossEntropyLoss](/api/python/docs/api/gluon/loss/index.html#mxnet.gluon.loss.SigmoidBinaryCrossEntropyLoss), [SoftmaxBinaryCrossEntropyLoss](/api/python/docs/api/gluon/loss/index.html#mxnet.gluon.loss.SoftmaxBinaryCrossEntropyLoss)\n",
"- Embeddings: [HingeLoss](/api/python/docs/api/gluon/loss/index.html#mxnet.gluon.loss.HingeLoss)\n",
"\n",
"However, we may sometimes want to solve problems that require customized loss functions; this tutorial shows how we can do that in Gluon. We will implement contrastive loss which is typically used in Siamese networks."
]
},
{
"cell_type": "markdown",
"id": "656dd68b",
"metadata": {},
"source": [
"```python\n",
"import matplotlib.pyplot as plt\n",
"import mxnet as mx\n",
"from mxnet import autograd, gluon, nd\n",
"from mxnet.gluon.loss import Loss\n",
"import random\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "cec1a46a",
"metadata": {},
"source": [
"### What is Contrastive Loss\n",
"\n",
"[Contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf) is a distance-based loss function. During training, pairs of images are fed into a model. If the images are similar, the loss function will return 0, otherwise 1.\n",
"\n",
"<img src=\"images/contrastive_loss.jpeg\" width=\"400\">\n",
"\n",
"*Y* is a binary label indicating similarity between training images. Contrastive loss uses the Euclidean distance *D* between images and is the sum of 2 terms:\n",
" - the loss for a pair of similar points\n",
" - the loss for a pair of dissimilar points\n",
"\n",
"The loss function uses a margin *m* which is has the effect that dissimlar pairs only contribute if their loss is within a certain margin.\n",
"\n",
"In order to implement such a customized loss function in Gluon, we only need to define a new class that is inheriting from the [Loss](/api/python/docs/api/gluon/loss/index.html#mxnet.gluon.loss.Loss) base class. We then define the contrastive loss logic in the [hybrid_forward](/api/python/docs/api/gluon/hybrid_block.html#mxnet.gluon.HybridBlock.hybrid_forward) method. This method takes the images `image1`, `image2` and the label which defines whether `image1` and `image2` are similar (=0) or dissimilar (=1). The input F is an `mxnet.ndarry` or an `mxnet.symbol` if we hybridize the network. Gluon's `Loss` base class is in fact a [HybridBlock](/api/python/docs/api/gluon/hybrid_block.html). This means we can either run imperatively or symbolically. When we hybridize our custom loss function, we can get performance speedups."
]
},
{
"cell_type": "markdown",
"id": "5a29ed8a",
"metadata": {},
"source": [
"```python\n",
"class ContrastiveLoss(Loss):\n",
" def __init__(self, margin=6., weight=None, batch_axis=0, **kwargs):\n",
" super(ContrastiveLoss, self).__init__(weight, batch_axis, **kwargs)\n",
" self.margin = margin\n",
"\n",
" def hybrid_forward(self, F, image1, image2, label):\n",
" distances = image1 - image2\n",
" distances_squared = F.sum(F.square(distances), 1, keepdims=True)\n",
" euclidean_distances = F.sqrt(distances_squared + 0.0001)\n",
" d = F.clip(self.margin - euclidean_distances, 0, self.margin)\n",
" loss = (1 - label) * distances_squared + label * F.square(d)\n",
" loss = 0.5*loss\n",
" return loss\n",
"loss = ContrastiveLoss(margin=6.0)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "4ac3cfbd",
"metadata": {},
"source": [
"### Define the Siamese network\n",
"A [Siamese network](https://papers.nips.cc/paper/769-signature-verification-using-a-siamese-time-delay-neural-network.pdf) consists of 2 identical networks, that share the same weights. They are trained on pairs of images and each network processes one image. The label defines whether the pair of images is similar or not. The Siamese network learns to differentiate between two input images.\n",
"\n",
"Our network consists of 2 convolutional and max pooling layers that downsample the input image. The output is then fed through a fully connected layer with 256 hidden units and another fully connected layer with 2 hidden units."
]
},
{
"cell_type": "markdown",
"id": "e264c2d5",
"metadata": {},
"source": [
"```python\n",
"class Siamese(gluon.HybridBlock):\n",
" def __init__(self, **kwargs):\n",
" super(Siamese, self).__init__(**kwargs)\n",
" with self.name_scope():\n",
" self.cnn = gluon.nn.HybridSequential()\n",
" with self.cnn.name_scope():\n",
" self.cnn.add(gluon.nn.Conv2D(64, 5, activation='relu'))\n",
" self.cnn.add(gluon.nn.MaxPool2D(2, 2))\n",
" self.cnn.add(gluon.nn.Conv2D(64, 5, activation='relu'))\n",
" self.cnn.add(gluon.nn.MaxPool2D(2, 2))\n",
" self.cnn.add(gluon.nn.Dense(256, activation='relu'))\n",
" self.cnn.add(gluon.nn.Dense(2, activation='softrelu'))\n",
"\n",
" def hybrid_forward(self, F, input0, input1):\n",
" out0 = self.cnn(input0)\n",
" out1 = self.cnn(input1)\n",
" return out0, out1\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "52d26cc1",
"metadata": {},
"source": [
"### Prepare the training data\n",
"\n",
"We train our network on the [Ominglot](http://www.omniglot.com/) dataset which is a collection of 1623 hand drawn characters from 50 alphabets. You can download it from [here](https://github.com/brendenlake/omniglot/tree/master/python). We need to create a dataset that contains a random set of similar and dissimilar images. We use Gluon's `ImageFolderDataset` where we overwrite `__getitem__` and randomly return similar and dissimilar pairs of images."
]
},
{
"cell_type": "markdown",
"id": "4edfcc1d",
"metadata": {},
"source": [
"```python\n",
"class GetImagePairs(mx.gluon.data.vision.ImageFolderDataset):\n",
" def __init__(self, root):\n",
" super(GetImagePairs, self).__init__(root, flag=0)\n",
" self.root = root\n",
"\n",
" def __getitem__(self, index):\n",
" items_with_index = list(enumerate(self.items))\n",
" image0_index, image0_tuple = random.choice(items_with_index)\n",
" should_get_same_class = random.randint(0, 1)\n",
" if should_get_same_class:\n",
" while True:\n",
" image1_index, image1_tuple = random.choice(items_with_index)\n",
" if image0_tuple[1] == image1_tuple[1]:\n",
" break\n",
" else:\n",
" image1_index, image1_tuple = random.choice(items_with_index)\n",
" image0 = super().__getitem__(image0_index)\n",
" image1 = super().__getitem__(image1_index)\n",
" label = mx.nd.array([int(image1_tuple[1] != image0_tuple[1])])\n",
" return image0[0], image1[0], label\n",
"\n",
" def __len__(self):\n",
" return super().__len__()\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "7a0538c9",
"metadata": {},
"source": [
"We train the network on a subset of the data, the [*Tifinagh*](https://www.omniglot.com/writing/tifinagh.htm) alphabet. Once the model is trained we test it on the [*Inuktitut*](https://www.omniglot.com/writing/inuktitut.htm) alphabet."
]
},
{
"cell_type": "markdown",
"id": "1c3e1082",
"metadata": {},
"source": [
"```python\n",
"def transform(img0, img1, label):\n",
" normalized_img0 = nd.transpose(img0.astype('float32'), (2, 0, 1))/255.0\n",
" normalized_img1 = nd.transpose(img1.astype('float32'), (2, 0, 1))/255.0\n",
" return normalized_img0, normalized_img1, label\n",
"\n",
"training_dir = \"images_background/Tifinagh\"\n",
"testing_dir = \"images_background/Inuktitut_(Canadian_Aboriginal_Syllabics)\"\n",
"train = GetImagePairs(training_dir)\n",
"test = GetImagePairs(testing_dir)\n",
"train_dataloader = gluon.data.DataLoader(train.transform(transform),\n",
" shuffle=True, batch_size=16)\n",
"test_dataloader = gluon.data.DataLoader(test.transform(transform),\n",
" shuffle=False, batch_size=1)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "db4ce984",
"metadata": {},
"source": [
"Following code plots some examples from the test dataset."
]
},
{
"cell_type": "markdown",
"id": "2dfef3e4",
"metadata": {},
"source": [
"```python\n",
"img1, img2, label = test[0]\n",
"print(\"Same: {}\".format(int(label.asscalar()) == 0))\n",
"fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 5))\n",
"ax0.imshow(img1.asnumpy()[:,:,0], cmap='gray')\n",
"ax0.axis('off')\n",
"ax1.imshow(img2.asnumpy()[:,:,0], cmap='gray')\n",
"ax1.axis(\"off\")\n",
"plt.show()\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "6c6744c1",
"metadata": {},
"source": [
"![example1](images/inuktitut_1.png)\n",
"\n",
"\n",
"### Train the Siamese network\n",
"\n",
"Before we can start training, we need to instantiate the custom constrastive loss function and initialize the model."
]
},
{
"cell_type": "markdown",
"id": "772ff74d",
"metadata": {},
"source": [
"```python\n",
"model = Siamese()\n",
"model.initialize(init=mx.init.Xavier())\n",
"trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': 0.001})\n",
"loss = ContrastiveLoss(margin=6.0)\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "8e3a8cdd",
"metadata": {},
"source": [
"Start the training loop:"
]
},
{
"cell_type": "markdown",
"id": "8da6ebe9",
"metadata": {},
"source": [
"```python\n",
"for epoch in range(10):\n",
" for i, data in enumerate(train_dataloader):\n",
" image1, image2, label = data\n",
" with autograd.record():\n",
" output1, output2 = model(image1, image2)\n",
" loss_contrastive = loss(output1, output2, label)\n",
" loss_contrastive.backward()\n",
" trainer.step(image1.shape[0])\n",
" loss_mean = loss_contrastive.mean().asscalar()\n",
" print(\"Epoch number {}\\n Current loss {}\\n\".format(epoch, loss_mean))\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "a9010151",
"metadata": {},
"source": [
"### Test the trained Siamese network\n",
"During inference we compute the Euclidean distance between the output vectors of the Siamese network. High distance indicates dissimilarity, low values indicate similarity."
]
},
{
"cell_type": "markdown",
"id": "14e48ef9",
"metadata": {},
"source": [
"```python\n",
"for i, data in enumerate(test_dataloader):\n",
" img1, img2, label = data\n",
" output1, output2 = model(img1, img2)\n",
" dist_sq = mx.ndarray.sum(mx.ndarray.square(output1 - output2))\n",
" dist = mx.ndarray.sqrt(dist_sq).asscalar()\n",
" print(\"Euclidean Distance:\", dist, \"Test label\", label[0].asscalar())\n",
" fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 5))\n",
" ax0.imshow(img1.asnumpy()[0, 0, :, :], cmap='gray')\n",
" ax0.axis('off')\n",
" ax1.imshow(img2.asnumpy()[0, 0, :, :], cmap='gray')\n",
" ax1.axis(\"off\")\n",
" plt.show()\n",
"\n",
"```\n"
]
},
{
"cell_type": "markdown",
"id": "15621a09",
"metadata": {},
"source": [
"![example2](images/inuktitut_2.png)\n",
"\n",
"\n",
"### Common pitfalls with custom loss functions\n",
"\n",
"When customizing loss functions, we may encounter certain pitfalls. If the loss is not decreasing as expected or if forward/backward pass is crashing, then one should check the following:\n",
"\n",
"#### Activation function in the last layer\n",
"Verify whether the last network layer uses the correct activation function: for instance in binary classification tasks we need to apply a sigmoid on the output data. If we use this activation in the last layer and define a loss function like Gluon's SigmoidBinaryCrossEntropy, we would basically apply sigmoid twice and the loss would not converge as expected. If we don't define any activation function, Gluon will per default apply a linear activation.\n",
"\n",
"#### Intermediate loss values\n",
"In our example, we computed the square root of squared distances between 2 images: `F.sqrt(distances_squared)`. If images are very similar we take the sqare root of a value close to 0, which can lead to *NaN* values. Adding a small epsilon to `distances_squared` avoids this problem.\n",
"\n",
"#### Shape of intermediate loss vectors\n",
"In most cases having the wrong tensor shape will lead to an error, as soon as we compare data with labels. But in some cases, we may be able to normally run the training, but it does not converge. For instance, if we don't set `keepdims=True` in our customized loss function, the shape of the tensor changes. The example still runs fine but does not converge.\n",
"\n",
"If you encounter a similar problem, then it is useful to check the tensor shape after each computation step in the loss function.\n",
"\n",
"#### Differentiable\n",
"Backprogration requires the loss function to be differentiable. If the customized loss function cannot be differentiated the backward pass will crash."
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}