#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose
from torchvision import transforms

from .imagenet16 import *


def get_dataloader(train_batch_size: int, test_batch_size: int, dataset: str,
                   num_workers: int, datadir: str, resize=None) -> Tuple[DataLoader, DataLoader, int]:
    """
    Load one of the supported image-classification datasets.
    :param train_batch_size: batch size for the training loader
    :param test_batch_size: batch size for the test loader
    :param dataset: one of cifar10, cifar100, svhn, ImageNet16-120,
                    ImageNet1k, ImageNet224-120, mnist
    :param num_workers: number of DataLoader worker processes
    :param datadir: root directory holding the dataset files
    :param resize: optional target image size; defaults to the dataset's native size
    :return: (train_loader, test_loader, class_num)
    """
    class_num = 0
    mean = []
    std = []
    # Defaults so that datasets without a branch below (e.g. mnist) do not
    # leave `size` unbound when the shared transforms are built.
    size, pad = 32, 0
    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        size, pad = 16, 2
    elif 'cifar' in dataset:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        size, pad = 32, 4
    elif 'svhn' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        size, pad = 32, 0
    elif dataset in ('ImageNet1k', 'ImageNet224-120'):
        size, pad = 224, 2
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        # resize = 256
    if resize is None:
        resize = size
    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    if dataset == 'cifar10':
        class_num = 10
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        class_num = 100
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset == 'svhn':
        class_num = 10
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        class_num = 120
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        class_num = 1000
        # Only the validation split is loaded; it serves as both the
        # "train" and test set here.
        test_dataset = ImageFolder(root=os.path.join(datadir, 'imagenet/val'), transform=test_transform)
        train_dataset = test_dataset
    elif dataset == 'ImageNet224-120':
        class_num = 120
        test_dataset = ImageFolder(root=os.path.join(datadir, 'imagenet/val'), transform=test_transform)
        # Keep only the first 120 classes (labels 0-119).
        class_indices = set(range(120))
        subset_indices = [i for i, (_, label) in enumerate(test_dataset.samples) if label in class_indices]
        filtered_test_dataset = Subset(test_dataset, subset_indices)
        train_dataset = filtered_test_dataset
        test_dataset = filtered_test_dataset
    elif dataset == 'mnist':
        data_transform = Compose([transforms.ToTensor()])
        # Normalization (e.g. transforms.Normalize((0.1307,), (0.3081,))) is intentionally skipped.
        train_dataset = MNIST("_dataset", True, data_transform, download=True)
        test_dataset = MNIST("_dataset", False, data_transform, download=True)
    else:
        raise ValueError(f'Unsupported dataset: {dataset}')
    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=False,
    )
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
    )
    print("dataset load done")
    return train_loader, test_loader, class_num
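
# A minimal usage sketch for get_dataloader; the './data' directory is an
# assumption, substitute your own path:
#
#   train_loader, test_loader, class_num = get_dataloader(
#       train_batch_size=128, test_batch_size=256, dataset='cifar10',
#       num_workers=4, datadir='./data')
#   images, targets = next(iter(train_loader))  # images: [128, 3, 32, 32]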


def get_mini_batch(dataloader: DataLoader, sample_alg: str, batch_size: int, num_classes: int) -> Tuple[Tensor, Tensor]:
    """
    Get a mini-batch of data.
    :param dataloader: source DataLoader
    :param sample_alg: sampling strategy, 'random' or 'grasp'
    :param batch_size: total number of samples to draw
    :param num_classes: number of classes (used by the 'grasp' strategy)
    :return: a tuple of (inputs, targets) tensors
    """
    if sample_alg == 'random':
        inputs, targets = _get_some_data(dataloader, batch_size=batch_size)
    elif sample_alg == 'grasp':
        inputs, targets = _get_some_data_grasp(dataloader, num_classes, samples_per_class=batch_size // num_classes)
    else:
        raise NotImplementedError(f'sample_alg {sample_alg} is not supported')
    return inputs, targets
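
# A sketch of the two sampling modes, assuming the train_loader from the
# get_dataloader example above and a 10-class dataset such as cifar10:
#
#   xs, ys = get_mini_batch(train_loader, 'random', batch_size=64, num_classes=10)
#   xs, ys = get_mini_batch(train_loader, 'grasp', batch_size=64, num_classes=10)
#
# 'grasp' draws batch_size // num_classes samples per class, so the returned
# batch is class-balanced and may be slightly smaller than batch_size when
# batch_size is not divisible by num_classes (here 60 rather than 64).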


def _get_some_data(train_dataloader: DataLoader, batch_size: int) -> Tuple[Tensor, Tensor]:
    """
    Randomly sample one batch of data; some classes may not be represented.
    :param train_dataloader: torch DataLoader
    :param batch_size: maximum number of samples to return
    :return: (inputs, targets) tensors
    """
    # Take the next batch from the loader and cap it at batch_size
    # (the loader's own batch size may be larger).
    inputs, targets = next(iter(train_dataloader))
    return inputs[:batch_size], targets[:batch_size]


def _get_some_data_grasp(train_dataloader: DataLoader, num_classes: int,
                         samples_per_class: int) -> Tuple[Tensor, Tensor]:
    """
    Sample data such that each class has an equal number of samples.
    :param train_dataloader: torch DataLoader
    :param num_classes: number of classes
    :param samples_per_class: how many samples to collect for each class
    :return: (inputs, targets) tensors
    """
    datas = [[] for _ in range(num_classes)]
    labels = [[] for _ in range(num_classes)]
    mark = dict()
    dataloader_iter = iter(train_dataloader)
    while True:
        try:
            inputs, targets = next(dataloader_iter)
        except StopIteration:
            # Restart the iterator if the loader is exhausted before every
            # class has collected samples_per_class examples.
            dataloader_iter = iter(train_dataloader)
            inputs, targets = next(dataloader_iter)
        for idx in range(inputs.shape[0]):
            x, y = inputs[idx:idx + 1], targets[idx:idx + 1]
            category = y.item()
            if len(datas[category]) == samples_per_class:
                mark[category] = True
                continue
            datas[category].append(x)
            labels[category].append(y)
        if len(mark) == num_classes:
            break
    x = torch.cat([torch.cat(class_xs, 0) for class_xs in datas])
    y = torch.cat([torch.cat(class_ys) for class_ys in labels]).view(-1)
    return x, y
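

# A hedged end-to-end smoke test, a sketch rather than part of the module's
# API. It assumes network access for the torchvision download and a writable
# './data' directory, and must be run as a module (python -m <package>.<this
# module>) because of the relative imports above.
if __name__ == '__main__':
    train_loader, test_loader, class_num = get_dataloader(
        train_batch_size=32, test_batch_size=32, dataset='cifar10',
        num_workers=2, datadir='./data')
    # Balanced sampling: 3 samples per class across 10 classes -> 30 examples.
    xs, ys = get_mini_batch(train_loader, 'grasp', batch_size=30, num_classes=10)
    print(xs.shape, ys.shape, class_num)  # expected: [30, 3, 32, 32], [30], 10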