# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
""" data iterator for mnist """
import os
import random
import tarfile
import logging
import tarfile
logging.basicConfig(level=logging.INFO)
import mxnet as mx
from mxnet.test_utils import get_cifar10
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.contrib.io import DataLoaderIter
def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    """Return (train, val) ImageRecordIter over CIFAR-10, downloading the record files if needed."""
get_cifar10()
train = mx.io.ImageRecordIter(
path_imgrec = "data/cifar/train.rec",
# mean_img = "data/cifar/mean.bin",
resize = resize,
data_shape = data_shape,
batch_size = batch_size,
rand_crop = True,
rand_mirror = True,
num_parts=num_parts,
part_index=part_index)
val = mx.io.ImageRecordIter(
path_imgrec = "data/cifar/test.rec",
# mean_img = "data/cifar/mean.bin",
resize = resize,
rand_crop = False,
rand_mirror = False,
data_shape = data_shape,
batch_size = batch_size,
num_parts=num_parts,
part_index=part_index)
return train, val
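
# A minimal usage sketch (hypothetical values; get_cifar10() downloads the
# CIFAR-10 record files into data/cifar/ on first use):
#
#   train, val = get_cifar10_iterator(batch_size=128, data_shape=(3, 28, 28))
#   for batch in train:          # batch is an mx.io.DataBatch
#       data, label = batch.data[0], batch.label[0]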

def get_imagenet_transforms(data_shape=224, dtype='float32'):
    """Return (train_transform, val_transform) for ImageNet-style preprocessing."""
def train_transform(image, label):
image, _ = mx.image.random_size_crop(image, (data_shape, data_shape), 0.08, (3/4., 4/3.))
image = mx.nd.image.random_flip_left_right(image)
image = mx.nd.image.to_tensor(image)
image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
return mx.nd.cast(image, dtype), label
def val_transform(image, label):
image = mx.image.resize_short(image, data_shape + 32)
image, _ = mx.image.center_crop(image, (data_shape, data_shape))
image = mx.nd.image.to_tensor(image)
image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
return mx.nd.cast(image, dtype), label
return train_transform, val_transform

def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
"""Dataset loader with preprocessing."""
train_dir = os.path.join(root, 'train')
train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
logging.info("Loading image folder %s, this may take a bit long...", train_dir)
    # train_transform takes (image, label), so apply it with transform(), not
    # transform_first(), which would call it with the image alone
    train_dataset = ImageFolderDataset(train_dir).transform(train_transform)
train_data = DataLoader(train_dataset, batch_size, shuffle=True,
last_batch='discard', num_workers=num_workers)
val_dir = os.path.join(root, 'val')
    if not os.path.isdir(os.path.expanduser(os.path.join(val_dir, 'n01440764'))):
        user_warning = ('Make sure validation images are stored in one subdir per '
                        'category; a helper script is available at https://git.io/vNQv1')
        raise ValueError(user_warning)
logging.info("Loading image folder %s, this may take a bit long...", val_dir)
val_dataset = ImageFolderDataset(val_dir).transform(val_transform)
val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
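
# A minimal usage sketch (assumes `root` holds an ImageNet layout with train/
# and val/<wnid>/ subfolders; paths and sizes here are hypothetical):
#
#   train, val = get_imagenet_iterator('data/imagenet', batch_size=64,
#                                      num_workers=4, data_shape=224)
#   for batch in train:
#       pass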

def get_caltech101_data():
    """Download and extract Caltech-101, returning (training_path, testing_path)."""
url = "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz"
dataset_name = "101_ObjectCategories"
data_folder = "data"
if not os.path.isdir(data_folder):
os.makedirs(data_folder)
tar_path = mx.gluon.utils.download(url, path=data_folder)
    if (not os.path.isdir(os.path.join(data_folder, dataset_name)) or
            not os.path.isdir(os.path.join(data_folder, "{}_test".format(dataset_name)))):
        with tarfile.open(tar_path, "r:gz") as tar:
            tar.extractall(data_folder)
        print('Data extracted')
training_path = os.path.join(data_folder, dataset_name)
testing_path = os.path.join(data_folder, "{}_test".format(dataset_name))
return training_path, testing_path

def get_caltech101_iterator(batch_size, num_workers, dtype):
    """Return (train, test) DataLoaderIter over Caltech-101."""
def transform(image, label):
# resize the shorter edge to 224, the longer edge will be greater or equal to 224
resized = mx.image.resize_short(image, 224)
        # center-crop to (224, 224)
        cropped, _ = mx.image.center_crop(resized, (224, 224))
        # transpose from (224, 224, 3) to (3, 224, 224)
        transposed = mx.nd.transpose(cropped, (2, 0, 1))
return transposed, label
training_path, testing_path = get_caltech101_data()
dataset_train = ImageFolderDataset(root=training_path).transform(transform)
dataset_test = ImageFolderDataset(root=testing_path).transform(transform)
train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
return DataLoaderIter(train_data), DataLoaderIter(test_data)
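
# A minimal usage sketch (downloads and extracts the Caltech-101 archive on
# first call; batch size and worker count are hypothetical):
#
#   train, test = get_caltech101_iterator(batch_size=32, num_workers=2,
#                                         dtype='float32')
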
class DummyIter(mx.io.DataIter):
    """Iterator that repeatedly yields the same zero batch; useful for benchmarking
    input pipelines and training loops without disk I/O."""
    def __init__(self, batch_size, data_shape, batches=100):
super(DummyIter, self).__init__(batch_size)
self.data_shape = (batch_size,) + data_shape
self.label_shape = (batch_size,)
self.provide_data = [('data', self.data_shape)]
self.provide_label = [('softmax_label', self.label_shape)]
self.batch = mx.io.DataBatch(data=[mx.nd.zeros(self.data_shape)],
label=[mx.nd.zeros(self.label_shape)])
self._batches = 0
self.batches = batches

    def next(self):
if self._batches < self.batches:
self._batches += 1
return self.batch
else:
self._batches = 0
raise StopIteration
def dummy_iterator(batch_size, data_shape):
return DummyIter(batch_size, data_shape), DummyIter(batch_size, data_shape)
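
# A minimal usage sketch (hypothetical shapes): DummyIter feeds constant zero
# batches, so iteration speed measures only the training loop itself:
#
#   train, val = dummy_iterator(batch_size=32, data_shape=(3, 224, 224))
#   for batch in train:          # 100 identical batches, then StopIteration
#       pass
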
class ImagePairIter(mx.io.DataIter):
    """Iterator yielding (input, target) image pairs from a folder of images, with
    separate augmenter lists producing the input and target versions."""
def __init__(self, path, data_shape, label_shape, batch_size=64, flag=0, input_aug=None, target_aug=None):
super(ImagePairIter, self).__init__(batch_size)
self.data_shape = (batch_size,) + data_shape
self.label_shape = (batch_size,) + label_shape
self.input_aug = input_aug
self.target_aug = target_aug
self.provide_data = [('data', self.data_shape)]
self.provide_label = [('label', self.label_shape)]
is_image_file = lambda fn: any(fn.endswith(ext) for ext in [".png", ".jpg", ".jpeg"])
self.filenames = [os.path.join(path, x) for x in os.listdir(path) if is_image_file(x)]
self.count = 0
self.flag = flag
random.shuffle(self.filenames)

    def next(self):
        from PIL import Image  # imported lazily so PIL is only required when this iterator is used
if self.count + self.batch_size <= len(self.filenames):
data = []
label = []
for i in range(self.batch_size):
fn = self.filenames[self.count]
self.count += 1
            # keep only the Y (luma) channel, and transpose so width <= height
            image = Image.open(fn).convert('YCbCr').split()[0]
            if image.size[0] > image.size[1]:
                image = image.transpose(Image.TRANSPOSE)
image = mx.np.expand_dims(mx.np.array(image), axis=2)
target = image.copy()
for aug in self.input_aug:
image = aug(image)
for aug in self.target_aug:
target = aug(target)
data.append(image)
label.append(target)
data = mx.np.concatenate([mx.np.expand_dims(d, axis=0) for d in data], axis=0)
label = mx.np.concatenate([mx.np.expand_dims(d, axis=0) for d in label], axis=0)
data = [mx.np.transpose(data, axes=(0, 3, 1, 2)).astype('float32')/255]
label = [mx.np.transpose(label, axes=(0, 3, 1, 2)).astype('float32')/255]
return mx.io.DataBatch(data=data, label=label)
else:
raise StopIteration

    def reset(self):
self.count = 0
random.shuffle(self.filenames)
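
# A minimal usage sketch (hypothetical path and augmenters; any mx.image
# augmenters that produce the declared shapes can be passed):
#
#   it = ImagePairIter('data/images', data_shape=(32, 32, 1),
#                      label_shape=(32, 32, 1), batch_size=16,
#                      input_aug=[], target_aug=[])
#   for batch in it:             # batch.data/label are NCHW float32 in [0, 1]
#       pass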