{"nbformat": 4, "cells": [{"source": "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n<!--- or more contributor license agreements. See the NOTICE file -->\n<!--- distributed with this work for additional information -->\n<!--- regarding copyright ownership. The ASF licenses this file -->\n<!--- to you under the Apache License, Version 2.0 (the -->\n<!--- \"License\"); you may not use this file except in compliance -->\n<!--- with the License. You may obtain a copy of the License at -->\n\n<!--- http://www.apache.org/licenses/LICENSE-2.0 -->\n\n<!--- Unless required by applicable law or agreed to in writing, -->\n<!--- software distributed under the License is distributed on an -->\n<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n<!--- KIND, either express or implied. See the License for the -->\n<!--- specific language governing permissions and limitations -->\n<!--- under the License. -->\n\n\n# Data Augmentation with Masks\n\n## Data Augmentation\n\nData Augmentation is a regularization technique that's used to avoid overfitting when training Machine Learning models. Although the technique can be applied in a variety of domains, it's very common in Computer Vision, and this will be the focus of the tutorial. Adjustments are made to the original images in the training dataset before being used in training. Some example adjustments include translating, croping, scaling, rotating, changing brightness and contrast. We do this to reduce the dependence of the model on spurious characteristics; e.g. training data may only contain faces that fill 1/4 of the image, so the model trainied without data augmentation might unhelpfully learn that faces can only be of this size.\n\n<img src=\"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/data_aug/outputs/with_mask/orig_vs_aug.png\" alt=\"Drawing\" style=\"width: 480px;\"/>\n\n## Masks\n\nCertain Computer Vision tasks (like [Object Segmentation](https://arxiv.org/abs/1506.06204)) require the use of 'masks', and we have to take extra care when using these in conjunction with data augmentation techniques. Given an underlying base image (with 3 channels), a masking channel can be added to provide additional metadata to certain regions of the base image. Masking channels often contain binary values, and these can be used to label a single class, e.g. to label a dog in the foreground. Multi-class segmentation problems could use many binary masking channels (i.e. one binary channel per class), but it is more common to see RGB representations, where each class is a different color. We take an example from the [COCO dataset](http://cocodataset.org/).\n\n<img src=\"https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/data_aug/outputs/with_mask/masks.png\" alt=\"Drawing\" style=\"width: 700px;\"/>\n\n## Data Augmentation with Masks\n\nWhen we adjust the position of the base image as part of data augmentation, we also need to apply exactly the same operation to the associated masks. An example would be after applying a horizontal flip to the base image, we'd need to also flip the mask, to preserve the corresponsence between the base image and mask.\n\nColor changes to the base image don't need to be applied to the segmentation masks though; and may even lead to errors with the masks. 
\n\n# Custom Dataset\n\nWith Gluon it's easy to work with different types of data. You can write custom Datasets and plug them directly into a DataLoader, which will handle batching. Segmentation tasks are structured in such a way that the data is the base image and the label is the mask, so we will create a custom Dataset for this. Our Dataset will return base images with their corresponding masks.\n\nIt will be based on the `mx.gluon.data.vision.ImageFolderDataset` for simplicity, and will load files from a single folder containing images of the form `xyz.jpg` and their corresponding masks `xyz_mask.png`.\n\n`__getitem__` must be implemented, as this will be used by the DataLoader.", "cell_type": "markdown", "metadata": {}}, {"source": "%matplotlib inline\nimport collections\nimport mxnet as mx  # used version '1.0.0' at time of writing\nfrom mxnet.gluon.data import dataset\nimport os\nimport numpy as np\nfrom matplotlib.pyplot import imshow\nimport matplotlib.pyplot as plt\n\nmx.random.seed(42)  # set seed for repeatability\n\n\nclass ImageWithMaskDataset(dataset.Dataset):\n    \"\"\"\n    A dataset for loading images (with masks) stored as `xyz.jpg` and `xyz_mask.png`.\n\n    Parameters\n    ----------\n    root : str\n        Path to root directory.\n    transform : callable, default None\n        A function that takes data and label and transforms them:\n        ::\n            transform = lambda data, label: (data.astype(np.float32)/255, label)\n    \"\"\"\n    def __init__(self, root, transform=None):\n        self._root = os.path.expanduser(root)\n        self._transform = transform\n        self._exts = ['.jpg', '.jpeg', '.png']\n        self._list_images(self._root)\n\n    def _list_images(self, root):\n        # Group each base image with its mask, keyed on the shared filename prefix.\n        images = collections.defaultdict(dict)\n        for filename in sorted(os.listdir(root)):\n            name, ext = os.path.splitext(filename)\n            mask_flag = name.endswith(\"_mask\")\n            if ext.lower() not in self._exts:\n                continue\n            if not mask_flag:\n                images[name][\"base\"] = filename\n            else:\n                name = name[:-5]  # to remove '_mask'\n                images[name][\"mask\"] = filename\n        self._image_list = list(images.values())\n\n    def __getitem__(self, idx):\n        assert 'base' in self._image_list[idx], \"Couldn't find base image for: \" + self._image_list[idx][\"mask\"]\n        base_filepath = os.path.join(self._root, self._image_list[idx][\"base\"])\n        base = mx.image.imread(base_filepath)\n        assert 'mask' in self._image_list[idx], \"Couldn't find mask image for: \" + self._image_list[idx][\"base\"]\n        mask_filepath = os.path.join(self._root, self._image_list[idx][\"mask\"])\n        mask = mx.image.imread(mask_filepath)\n        if self._transform is not None:\n            return self._transform(base, mask)\n        else:\n            return base, mask\n\n    def __len__(self):\n        return len(self._image_list)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Using our Dataset\n\nUsually Datasets are used in conjunction with DataLoaders, but we'll sample a single base image and mask pair for testing purposes. Calling `dataset[0]` (which is equivalent to `dataset.__getitem__(0)`) returns the first base image and mask pair from the `_image_list`.
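\n\nAs an aside, here is a minimal sketch (not used in the rest of this tutorial) of how the Dataset would typically be wrapped in a DataLoader for batching. The batch size of 8 is an arbitrary choice, and batching requires all samples to share the same shape, so a resizing transform (like the one defined later in this tutorial) would be applied first:\n\n```\nfrom mxnet.gluon.data import DataLoader\n\n# Hypothetical usage: yields pairs of NDArrays with a leading batch dimension.\nloader = DataLoader(dataset, batch_size=8, shuffle=True)\nfor base_batch, mask_batch in loader:\n    pass  # feed each batch into the training loop here\n```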
\n\nFirst, download the sample images; then we'll load them without any augmentation.", "cell_type": "markdown", "metadata": {}}, {"source": "image_dir = os.path.join(\"data\", \"images\")\nmx.test_utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/data_aug/inputs/0.jpg', dirname=image_dir)\nmx.test_utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/data_aug/inputs/0_mask.png', dirname=image_dir)\ndataset = ImageWithMaskDataset(root=image_dir)\nsample = dataset.__getitem__(0)\nsample_base = sample[0].astype('float32')\nsample_mask = sample[1].astype('float32')\nassert sample_base.shape == (427, 640, 3)\nassert sample_mask.shape == (427, 640, 3)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "def plot_mx_arrays(arrays):\n    \"\"\"\n    Arrays are expected to be height x width x 3 (channels), with float values between 0 and 255.\n    \"\"\"\n    plt.subplots(figsize=(12, 4))\n    for idx, array in enumerate(arrays):\n        assert array.shape[2] == 3, \"RGB Channel should be last\"\n        plt.subplot(1, 2, idx+1)\n        imshow((array.clip(0, 255)/255).asnumpy())", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "plot_mx_arrays([sample_base, sample_mask])", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "# Implementing `transform` for Augmentation\n\nWe now construct our augmentation pipeline by implementing a transform function. Given a data sample and its corresponding label, this function must also return data and a label. In our specific example, our transform function will take the base image and corresponding mask, and return the augmented base image and correctly augmented mask. We will provide this to the `ImageWithMaskDataset` via the `transform` argument, and it will be applied to each sample (i.e. each data and label pair).\n\nOur approach is to apply positional augmentations to the combined base image and mask, and then apply the color augmentations to the positionally augmented base image only. We concatenate the base image with the mask along the channels dimension; so if we have a 3 channel base image and a 3 channel mask, the result will be a 6 channel array. After applying positional augmentations to this array, we split out the base image and mask once again. Our last step is to apply the color augmentation to just the augmented base image.
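\n\nTo make the channel bookkeeping concrete, here is a quick sketch with zero-valued stand-in arrays (the 100 x 100 size is arbitrary):\n\n```\nbase = mx.nd.zeros((100, 100, 3))\nmask = mx.nd.zeros((100, 100, 3))\njoint = mx.nd.concat(base, mask, dim=2)  # shape: (100, 100, 6)\naug_base = joint[:, :, :3]  # first 3 channels: the base image\naug_mask = joint[:, :, 3:]  # last 3 channels: the mask\n```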
", "cell_type": "markdown", "metadata": {}}, {"source": "def positional_augmentation(joint):\n    # Random crop\n    crop_height = 200\n    crop_width = 200\n    aug = mx.image.RandomCropAug(size=(crop_width, crop_height))  # Watch out: width comes before height in the size param!\n    aug_joint = aug(joint)\n    # Deterministic resize\n    resize_size = 100\n    aug = mx.image.ResizeAug(resize_size)\n    aug_joint = aug(aug_joint)\n    # Add more translation/scale/rotation augmentations here...\n    return aug_joint\n\n\ndef color_augmentation(base):\n    # Only applied to the base image, and not the mask layers.\n    aug = mx.image.BrightnessJitterAug(brightness=0.2)\n    aug_base = aug(base)\n    # Add more color augmentations here...\n    return aug_base\n\n\ndef joint_transform(base, mask):\n    ### Convert types\n    base = base.astype('float32')/255\n    mask = mask.astype('float32')/255\n\n    ### Join\n    # Concatenate on the channels dim, to obtain a 6 channel image\n    # (3 channels for the base image, plus 3 channels for the mask)\n    base_channels = base.shape[2]  # so we know where to split later on\n    joint = mx.nd.concat(base, mask, dim=2)\n\n    ### Augmentation Part 1: positional\n    aug_joint = positional_augmentation(joint)\n\n    ### Split\n    aug_base = aug_joint[:, :, :base_channels]\n    aug_mask = aug_joint[:, :, base_channels:]\n\n    ### Augmentation Part 2: color\n    aug_base = color_augmentation(aug_base)\n\n    return aug_base, aug_mask", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "# Using Augmentation\n\nIt's simple to use augmentation now that we have the `joint_transform` function defined. Simply set the `transform` argument when defining the Dataset. You'll notice the alignment between the base image and the mask is preserved, and the mask colors are left unchanged.", "cell_type": "markdown", "metadata": {}}, {"source": "image_dir = os.path.join(\"data\", \"images\")\nds = ImageWithMaskDataset(root=image_dir, transform=joint_transform)\nsample = ds.__getitem__(0)\nassert len(sample) == 2\nassert sample[0].shape == (100, 100, 3)\nassert sample[1].shape == (100, 100, 3)\nplot_mx_arrays([sample[0]*255, sample[1]*255])", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "# Summary\n\nWe've successfully created a custom Dataset for images and corresponding masks, implemented an augmentation `transform` function that correctly handles masks, and applied it to each sample of the Dataset. You're now ready to train your own object segmentation models!\n\n# Appendix (COCO Dataset)\n\nThe [COCO dataset](http://cocodataset.org/) is a great resource for image segmentation data. It contains over 200k labelled images, with over 1.5 million object instances across 80 object categories. You can download the data using `gsutil` as per the instructions below (from http://cocodataset.org/#download):\n\n### 1) Install `gsutil`\n\n```\ncurl https://sdk.cloud.google.com | bash\n```\n\n### 2) Download Images\n\nAs an example, we download the validation data from 2017 from `gs://images.cocodataset.org/val2017`. It's a much more manageable size (~770MB) compared to the test and training data, which are both > 5GB.\n\n```\nmkdir coco_data\nmkdir coco_data/images\ngsutil -m rsync gs://images.cocodataset.org/val2017 coco_data/images\n```\n\n### 3) Download Masks (a.k.a. 
pixel maps) \n\n```\ngsutil -m cp gs://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip \\\n coco_data/stuff_annotations_trainval2017.zip\nunzip coco_data/stuff_annotations_trainval2017.zip\nrm coco_data/stuff_annotations_trainval2017.zip\nunzip annotations/stuff_val2017_pixelmaps.zip\nrm -r annotations\nmkdir coco_data/masks\nmv -v stuff_val2017_pixelmaps/* coco_data/masks/\nrm -r stuff_val2017_pixelmaps\n```\n\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}