| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import mxnet.gluon.data as data |
| |
| from PIL import Image |
| import os |
| import os.path |
| |
# File suffixes recognised as images. Each suffix is listed in its lower-case
# spelling immediately followed by its upper-case spelling; matching against
# this list is case-sensitive, so mixed-case spellings (e.g. '.Jpg') are not included.
IMG_EXTENSIONS = [
    spelling
    for suffix in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for spelling in (suffix, suffix.upper())
]
| |
| |
def is_image_file(filename):
    """Tell whether *filename* carries one of the suffixes in ``IMG_EXTENSIONS``.

    The check is an exact, case-sensitive suffix match: only the spellings
    listed in ``IMG_EXTENSIONS`` are accepted.

    Args:
        filename (string): File name (or path) to test.

    Returns:
        bool: True if the name ends with a supported image extension.
    """
    # str.endswith accepts a tuple and succeeds if any element matches.
    return filename.endswith(tuple(IMG_EXTENSIONS))
| |
| |
def find_classes(dir):
    """Discover class names from the immediate sub-directories of *dir*.

    Args:
        dir (string): Root directory. A leading ``~`` is expanded, the same
            way ``make_dataset`` expands it.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of sub-directory names and ``class_to_idx`` maps each name to its
        position in that list.
    """
    # Expand '~' here as well: ImageFolder calls this before make_dataset,
    # so without it a '~/...' root fails in os.listdir.
    dir = os.path.expanduser(dir)
    # Plain files at the top level are ignored; only directories are classes.
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx
| |
| |
def make_dataset(dir, class_to_idx):
    """Collect ``(path, class_index)`` pairs for every image under *dir*.

    Args:
        dir (string): Root directory; a leading ``~`` is expanded.
        class_to_idx (dict): Maps class (sub-directory) names to integer labels.
            Sub-directories whose name is not a key are skipped rather than
            raising ``KeyError``, so a partial mapping selects a subset of classes.

    Returns:
        list: ``(image path, class_index)`` tuples, ordered by class name, then
        by directory path within the class, then by file name.
    """
    images = []
    dir = os.path.expanduser(dir)
    # os.listdir is kept (rather than iterating the mapping) so a missing root
    # still raises instead of silently producing an empty dataset.
    for target in sorted(os.listdir(dir)):
        if target not in class_to_idx:
            continue
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue
        class_index = class_to_idx[target]
        # os.walk yields directories in filesystem order; sort for determinism.
        for root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    images.append((os.path.join(root, fname), class_index))

    return images
| |
| |
def pil_loader(path):
    """Open the image stored at *path* and return it converted to RGB mode.

    Args:
        path (string): Location of the image file.

    Returns:
        PIL.Image.Image: A new RGB image produced by ``convert('RGB')``.
    """
    # Hand PIL an explicitly opened file object so the descriptor is closed
    # deterministically (avoids ResourceWarning,
    # https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image
| |
| |
class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
            Defaults to ``pil_loader``, which returns an RGB PIL image.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Raises:
        RuntimeError: If no file with a supported extension is found in the
            class sub-directories of ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=pil_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            # Use format() instead of '+' so a non-str root (e.g. pathlib.Path)
            # still yields this RuntimeError rather than a TypeError.
            raise RuntimeError(
                "Found 0 images in subfolders of: {}\n"
                "Supported image extensions are: {}".format(
                    root, ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position in ``self.imgs``.

        Returns:
            tuple: (image, target) where target is class_index of the target class.
            Each element is passed through its transform when one was given.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of (image path, class_index) samples."""
        return len(self.imgs)