| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # pylint: skip-file |
| """ data iterator for mnist """ |
| import os |
| import random |
| import tarfile |
| import logging |
| logging.basicConfig(level=logging.INFO) |
| |
| import mxnet as mx |
| from mxnet.test_utils import get_cifar10 |
| from mxnet.gluon.data.vision import ImageFolderDataset |
| from mxnet.gluon.data import DataLoader |
| from mxnet.contrib.io import DataLoaderIter |
| |
def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    """Return train and validation ImageRecordIter for CIFAR-10, downloading
    the record files into data/cifar/ if necessary."""
    get_cifar10()

    train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        # mean_img="data/cifar/mean.bin",
        resize=resize,
        data_shape=data_shape,
        batch_size=batch_size,
        rand_crop=True,
        rand_mirror=True,
        num_parts=num_parts,
        part_index=part_index)

    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        # mean_img="data/cifar/mean.bin",
        resize=resize,
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=batch_size,
        num_parts=num_parts,
        part_index=part_index)

    return train, val
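
# A minimal usage sketch (the values below are illustrative, not part of the
# original example); CIFAR-10 images are 32x32 RGB, so data_shape=(3, 32, 32):
#
#   train_iter, val_iter = get_cifar10_iterator(batch_size=128,
#                                               data_shape=(3, 32, 32))
#   for batch in train_iter:
#       pass  # batch.data[0] is an NDArray of shape (128, 3, 32, 32)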
| |
def get_imagenet_transforms(data_shape=224, dtype='float32'):
    """Return (train_transform, val_transform): random-resized crop and
    horizontal flip for training, resize and center crop for validation,
    both followed by to_tensor and ImageNet mean/std normalization."""
| def train_transform(image, label): |
| image, _ = mx.image.random_size_crop(image, (data_shape, data_shape), 0.08, (3/4., 4/3.)) |
| image = mx.nd.image.random_flip_left_right(image) |
| image = mx.nd.image.to_tensor(image) |
| image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) |
| return mx.nd.cast(image, dtype), label |
| |
| def val_transform(image, label): |
| image = mx.image.resize_short(image, data_shape + 32) |
| image, _ = mx.image.center_crop(image, (data_shape, data_shape)) |
| image = mx.nd.image.to_tensor(image) |
| image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) |
| return mx.nd.cast(image, dtype), label |
| return train_transform, val_transform |
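
# The transforms above can also be applied outside a DataLoader; a hedged
# sketch (the image path is hypothetical, and `img` is assumed to be an HWC
# uint8 NDArray as returned by mx.image.imread):
#
#   train_tf, val_tf = get_imagenet_transforms(data_shape=224)
#   img = mx.image.imread('example.jpg')
#   tensor, label = val_tf(img, 0)  # tensor: normalized float32 CHW, shape (3, 224, 224)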
| |
| def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'): |
| """Dataset loader with preprocessing.""" |
| train_dir = os.path.join(root, 'train') |
| train_transform, val_transform = get_imagenet_transforms(data_shape, dtype) |
| logging.info("Loading image folder %s, this may take a bit long...", train_dir) |
    # transform takes (image, label) pairs, matching the two-argument transforms above
    train_dataset = ImageFolderDataset(train_dir).transform(train_transform)
| train_data = DataLoader(train_dataset, batch_size, shuffle=True, |
| last_batch='discard', num_workers=num_workers) |
    val_dir = os.path.join(root, 'val')
    if not os.path.isdir(os.path.expanduser(os.path.join(val_dir, 'n01440764'))):
        raise ValueError('Make sure validation images are stored in one subdir per category, '
                         'a helper script is available at https://git.io/vNQv1')
    logging.info("Loading image folder %s, this may take a while...", val_dir)
| val_dataset = ImageFolderDataset(val_dir).transform(val_transform) |
| val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers) |
| return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype) |
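
# Usage sketch (assumes `root` contains train/ and val/ directories with one
# subfolder per class, as produced by the usual ImageNet preparation scripts):
#
#   train_iter, val_iter = get_imagenet_iterator('data/imagenet', batch_size=64,
#                                                num_workers=4)
#   for batch in train_iter:
#       pass  # batch.data[0]: (64, 3, 224, 224), batch.label[0]: (64,)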
| |
| def get_caltech101_data(): |
| url = "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz" |
| dataset_name = "101_ObjectCategories" |
| data_folder = "data" |
| if not os.path.isdir(data_folder): |
| os.makedirs(data_folder) |
| tar_path = mx.gluon.utils.download(url, path=data_folder) |
    if (not os.path.isdir(os.path.join(data_folder, dataset_name)) or
            not os.path.isdir(os.path.join(data_folder, "{}_test".format(dataset_name)))):
        with tarfile.open(tar_path, "r:gz") as tar:
            tar.extractall(data_folder)
        print('Data extracted')
| training_path = os.path.join(data_folder, dataset_name) |
| testing_path = os.path.join(data_folder, "{}_test".format(dataset_name)) |
| return training_path, testing_path |
| |
def get_caltech101_iterator(batch_size, num_workers, dtype):
    def transform(image, label):
        # resize the shorter edge to 224; the longer edge will be greater than or equal to 224
        resized = mx.image.resize_short(image, 224)
        # center crop an area of size (224, 224)
        cropped, _ = mx.image.center_crop(resized, (224, 224))
        # transpose from HWC (224, 224, 3) to CHW (3, 224, 224) and cast to the requested dtype
        transposed = mx.nd.transpose(cropped, (2, 0, 1)).astype(dtype)
        return transposed, label
| |
| training_path, testing_path = get_caltech101_data() |
| dataset_train = ImageFolderDataset(root=training_path).transform(transform) |
| dataset_test = ImageFolderDataset(root=testing_path).transform(transform) |
| |
| train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers) |
| test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers) |
| return DataLoaderIter(train_data), DataLoaderIter(test_data) |
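
# Usage sketch; the first call downloads and extracts the Caltech-101 archive
# into data/ via get_caltech101_data():
#
#   train_iter, test_iter = get_caltech101_iterator(batch_size=32,
#                                                   num_workers=2,
#                                                   dtype='float32')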
| |
class DummyIter(mx.io.DataIter):
    """Iterator that yields the same zero-filled batch a fixed number of
    times, useful for benchmarking without any input pipeline cost."""
    def __init__(self, batch_size, data_shape, batches=100):
| super(DummyIter, self).__init__(batch_size) |
| self.data_shape = (batch_size,) + data_shape |
| self.label_shape = (batch_size,) |
| self.provide_data = [('data', self.data_shape)] |
| self.provide_label = [('softmax_label', self.label_shape)] |
| self.batch = mx.io.DataBatch(data=[mx.nd.zeros(self.data_shape)], |
| label=[mx.nd.zeros(self.label_shape)]) |
| self._batches = 0 |
| self.batches = batches |
| |
| def next(self): |
| if self._batches < self.batches: |
| self._batches += 1 |
| return self.batch |
| else: |
| self._batches = 0 |
| raise StopIteration |
| |
| def dummy_iterator(batch_size, data_shape): |
| return DummyIter(batch_size, data_shape), DummyIter(batch_size, data_shape) |
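
# DummyIter serves the same zero-filled batch repeatedly, so it is handy for
# measuring model throughput with the input pipeline taken out of the picture:
#
#   train_iter, val_iter = dummy_iterator(batch_size=64, data_shape=(3, 224, 224))
#   for batch in train_iter:  # yields 100 identical batches, then stops
#       pass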
| |
class ImagePairIter(mx.io.DataIter):
    """Iterator over (input, target) image pairs built from the luminance (Y)
    channel of the images found in `path`, e.g. for super-resolution training."""
    def __init__(self, path, data_shape, label_shape, batch_size=64, flag=0, input_aug=None, target_aug=None):
        super(ImagePairIter, self).__init__(batch_size)
        self.data_shape = (batch_size,) + data_shape
        self.label_shape = (batch_size,) + label_shape
        # default to no augmentation; iterating over None would raise a TypeError in next()
        self.input_aug = input_aug if input_aug is not None else []
        self.target_aug = target_aug if target_aug is not None else []
        self.provide_data = [('data', self.data_shape)]
        self.provide_label = [('label', self.label_shape)]
        is_image_file = lambda fn: any(fn.endswith(ext) for ext in [".png", ".jpg", ".jpeg"])
        self.filenames = [os.path.join(path, x) for x in os.listdir(path) if is_image_file(x)]
        self.count = 0
        self.flag = flag
        random.shuffle(self.filenames)
| |
| def next(self): |
        from PIL import Image  # imported lazily so Pillow is only needed when this iterator is used
| if self.count + self.batch_size <= len(self.filenames): |
| data = [] |
| label = [] |
| for i in range(self.batch_size): |
| fn = self.filenames[self.count] |
| self.count += 1 |
                # keep only the Y (luminance) channel
                image = Image.open(fn).convert('YCbCr').split()[0]
                # transpose landscape images so every sample has the same orientation
                if image.size[0] > image.size[1]:
                    image = image.transpose(Image.TRANSPOSE)
| image = mx.np.expand_dims(mx.np.array(image), axis=2) |
| target = image.copy() |
| for aug in self.input_aug: |
| image = aug(image) |
| for aug in self.target_aug: |
| target = aug(target) |
| data.append(image) |
| label.append(target) |
| |
            # stack the per-image arrays into NHWC batches, then convert to NCHW float32 in [0, 1]
            data = mx.np.concatenate([mx.np.expand_dims(d, axis=0) for d in data], axis=0)
            label = mx.np.concatenate([mx.np.expand_dims(d, axis=0) for d in label], axis=0)
            data = [mx.np.transpose(data, axes=(0, 3, 1, 2)).astype('float32') / 255]
            label = [mx.np.transpose(label, axes=(0, 3, 1, 2)).astype('float32') / 255]
| |
| return mx.io.DataBatch(data=data, label=label) |
| else: |
| raise StopIteration |
| |
| def reset(self): |
| self.count = 0 |
| random.shuffle(self.filenames) |
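
# Usage sketch for input/target pairs, e.g. super-resolution training; the
# directory and shapes are illustrative, and augmenters (any callables mapping
# an image array to an image array) can be passed via input_aug/target_aug:
#
#   pair_iter = ImagePairIter('data/images', data_shape=(64, 64),
#                             label_shape=(128, 128), batch_size=16)
#   for batch in pair_iter:
#       pass  # batch.data[0]: NCHW float32 in [0, 1]
#   pair_iter.reset()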