#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import os
import random
import sys
import time
import warnings
import numpy as np
import torch
import shutil
import logging
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.nn as nn
warnings.filterwarnings("error")  # escalate every warning to a hard error
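# Format an elapsed time as "Xh Ym Zs", given either a start timestamp (`since`)
# or an already-computed number of seconds (`s`).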
def timeSince(since=None, s=None):
if s is None:
s = int(time.time() - since)
m = math.floor(s / 60)
s %= 60
h = math.floor(m / 60)
m %= 60
return '%dh %dm %ds' % (h, m, s)
class AvgrageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
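# Number of samples in the batch whose argmax prediction matches the target label.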
def get_correct_num(y, target):
pred_label = torch.argmax(y, dim=1)
return (target == pred_label).sum().item()
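# Standard top-k accuracy: for each k, count targets found among the k highest
# logits and scale to a percentage of the batch.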
def accuracy(output, target, topk=(1,)):
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
        # use reshape: .view(-1) can fail here on newer PyTorch because the sliced tensor is not contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
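# Cutout augmentation: zero out one randomly positioned `length` x `length`
# square patch of the (already tensor-ized) image.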
class Cutout(object):
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def _data_transforms_cifar10(args):
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
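# CIFAR-10 train/validation DataLoaders built from the transforms above.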
def _get_cifar10(args):
train_transform, valid_transform = _data_transforms_cifar10(args)
train_data = dset.CIFAR10(
root=args.data, train=True, download=True, transform=train_transform
)
valid_data = dset.CIFAR10(
root=args.data, train=False, download=True, transform=valid_transform
)
train_queue = torch.utils.data.DataLoader(
train_data,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=4,
)
valid_queue = torch.utils.data.DataLoader(
valid_data,
batch_size=args.batch_size,
shuffle=False,
pin_memory=True,
num_workers=4,
)
return train_queue, valid_queue
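# Distributed CIFAR-10 variant: a DistributedSampler shards the training set
# across `args.gpu_num` processes and the global batch size is split per process;
# the sampler is returned so the caller can invoke set_epoch() each epoch to reshuffle.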
def _get_dist_cifar10(args):
train_transform, valid_transform = _data_transforms_cifar10(args)
train_data = dset.CIFAR10(
root=args.data, train=True, download=True, transform=train_transform
)
valid_data = dset.CIFAR10(
root=args.data, train=False, download=True, transform=valid_transform
)
sampler = torch.utils.data.distributed.DistributedSampler(
train_data, num_replicas=args.gpu_num, rank=args.local_rank)
train_queue = torch.utils.data.DataLoader(
train_data,
batch_size=args.batch_size // args.gpu_num,
pin_memory=True,
num_workers=4,
drop_last=True,
sampler=sampler
)
valid_queue = torch.utils.data.DataLoader(
valid_data,
batch_size=args.batch_size,
shuffle=False,
pin_memory=True,
num_workers=4,
)
return train_queue, valid_queue, sampler
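# ImageNet loaders for distributed training: 224x224 random-resized crops with
# horizontal flip and color jitter for training, resize(256) + center-crop(224)
# for validation; the training batch is likewise split across args.gpu_num processes.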
def _get_dist_imagenet(args):
traindir = os.path.join(args.data_dir, 'train')
valdir = os.path.join(args.data_dir, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dataset = dset.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
]))
sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, num_replicas=args.gpu_num, rank=args.local_rank)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size // args.gpu_num, num_workers=max(args.gpu_num * 2, 4),
pin_memory=True, drop_last=True, sampler=sampler)
val_loader = torch.utils.data.DataLoader(
dset.ImageFolder(valdir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=4, pin_memory=True)
return train_loader, val_loader, sampler
def _data_transforms_cifar100(args):
CIFAR_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
CIFAR_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
def _get_cifar100(args):
train_transform, valid_transform = _data_transforms_cifar100(args)
train_data = dset.CIFAR100(
root=args.data, train=True, download=True, transform=train_transform
)
valid_data = dset.CIFAR100(
root=args.data, train=False, download=True, transform=valid_transform
)
train_queue = torch.utils.data.DataLoader(
train_data,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=4,
)
valid_queue = torch.utils.data.DataLoader(
valid_data,
batch_size=args.batch_size,
shuffle=False,
pin_memory=True,
num_workers=4,
)
return train_queue, valid_queue
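# Tiny-ImageNet (64x64 images) loaders; validation uses half of the training batch size.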
def _get_imagenet_tiny(args):
traindir = os.path.join(args.data, 'train')
validdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(
mean=[0.4802, 0.4481, 0.3975],
std=[0.2302, 0.2265, 0.2262]
)
train_transform = transforms.Compose([
transforms.RandomCrop(64, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
if args.cutout:
train_transform.transforms.append(Cutout(args.cutout_length))
train_data = dset.ImageFolder(
traindir,
train_transform
)
valid_data = dset.ImageFolder(
validdir,
transforms.Compose([
transforms.ToTensor(),
normalize,
])
)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=args.batch_size // 2, shuffle=False, pin_memory=True, num_workers=4)
return train_queue, valid_queue
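# Parameter count in millions (sum of element counts over all registered parameters).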
def count_parameters_in_MB(model):
return np.sum([np.prod(v.size()) for v in model.parameters()]) / 1e6
def count_parameters(model):
"""
Get element number of all parameters matrix.
:param model:
:return:
"""
return sum([torch.numel(v) for v in model.parameters()])
def save(model, model_path):
torch.save(model.state_dict(), model_path)
def load(model, model_path):
model.load_state_dict(torch.load(model_path))
def load_ckpt(ckpt_path):
print(f'=> loading checkpoint {ckpt_path}...')
    try:
        checkpoint = torch.load(ckpt_path)
    except (OSError, RuntimeError):
        print(f"=> failed to load {ckpt_path}")
        sys.exit(1)
return checkpoint
def save_ckpt(ckpt, file_dir, file_name='model.ckpt', is_best=False):
if not os.path.exists(file_dir): os.makedirs(file_dir)
ckpt_path = os.path.join(file_dir, file_name)
torch.save(ckpt, ckpt_path)
if is_best: shutil.copyfile(ckpt_path, os.path.join(file_dir, f'best_{file_name}'))
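# Drop-path regularization: with probability `drop_prob`, zero an entire sample
# (or whichever dimensions `dims` selects) and rescale survivors by 1/keep_prob,
# in place, so the expected activation is unchanged.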
def drop_path(x, drop_prob, dims=(0,)):
var_size = [1 for _ in range(x.dim())]
for i in dims:
var_size[i] = x.size(i)
if drop_prob > 0.:
keep_prob = 1. - drop_prob
        # Variable is deprecated; sample the Bernoulli keep-mask directly on x's device
        mask = torch.empty(var_size, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
x.div_(keep_prob)
x.mul_(mask)
return x
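# Create the experiment directory and, if requested, snapshot the given scripts
# into a tools/ subdirectory for reproducibility.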
def create_exp_dir(path, scripts_to_save=None):
if not os.path.exists(path):
os.makedirs(path)
print('Experiment dir : {}'.format(path))
if scripts_to_save is not None:
os.makedirs(os.path.join(path, 'tools'))
for script in scripts_to_save:
dst_file = os.path.join(path, 'tools', os.path.basename(script))
shutil.copyfile(script, dst_file)
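# Accumulates one row per update: the softmax-normalised architecture weights
# (normal and reduce alphas) plus the validation loss, saved to `path` as a NumPy array.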
class Performance(object):
def __init__(self, path):
self.path = path
self.data = None
def update(self, alphas_normal, alphas_reduce, val_loss):
a_normal = F.softmax(alphas_normal, dim=-1)
# print("alpha normal size: ", a_normal.data.size())
a_reduce = F.softmax(alphas_reduce, dim=-1)
# print("alpha reduce size: ", a_reduce.data.size())
data = np.concatenate([a_normal.data.view(-1),
a_reduce.data.view(-1),
np.array([val_loss.data])]).reshape(1, -1)
if self.data is not None:
self.data = np.concatenate([self.data, data], axis=0)
else:
self.data = data
def save(self):
np.save(self.path, self.data)
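# Build a DEBUG-level logger that writes to the file at `log_dir` (a file path,
# despite the name), optionally echoing to stdout; timestamps are added only
# when need_time is set.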
def logger(log_dir, need_time=True, need_stdout=False):
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
fh = logging.FileHandler(log_dir)
fh.setLevel(logging.DEBUG)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y-%I:%M:%S')
if need_stdout:
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.DEBUG)
log.addHandler(ch)
if need_time:
fh.setFormatter(formatter)
if need_stdout:
ch.setFormatter(formatter)
log.addHandler(fh)
return log
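# Cross-entropy with label smoothing: targets become
# (1 - epsilon) * one_hot + epsilon / num_classes before the negative
# log-likelihood is averaged over the batch.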
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def roc_auc_compute_fn(y_pred, y_target):
""" IGNITE.CONTRIB.METRICS.ROC_AUC """
try:
from sklearn.metrics import roc_auc_score
except ImportError:
raise RuntimeError("This contrib module requires sklearn to be installed.")
if y_pred.requires_grad:
y_pred = y_pred.detach()
if y_target.is_cuda:
y_target = y_target.cpu()
if y_pred.is_cuda:
y_pred = y_pred.cpu()
y_true = y_target.numpy()
y_pred = y_pred.numpy()
try:
return roc_auc_score(y_true, y_pred)
except ValueError:
# print('ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.')
return 0.
def load_checkpoint(args):
try:
return torch.load(args.resume)
    except (OSError, RuntimeError) as e:
        raise RuntimeError(f"Failed to load checkpoint at {args.resume}") from e
def save_checkpoint(ckpt, is_best, file_dir, file_name='model.ckpt'):
if not os.path.exists(file_dir):
os.makedirs(file_dir)
ckpt_name = "{0}{1}".format(file_dir, file_name)
torch.save(ckpt, ckpt_name)
if is_best: shutil.copyfile(ckpt_name, "{0}{1}".format(file_dir, 'best_' + file_name))
def seed_everything(seed=2022):
''' [reference] https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335 '''
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
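

if __name__ == '__main__':
    # Minimal usage sketch (illustration only, not part of the training pipeline):
    # exercise a few of the helpers above on random data to show how they are called.
    seed_everything(0)
    meter = AvgrageMeter()
    logits = torch.randn(8, 10)            # fake batch of 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    meter.update(top1.item(), n=labels.size(0))
    start = time.time()
    print(f'top1={meter.avg:.2f}% correct={get_correct_num(logits, labels)} '
          f'elapsed={timeSince(since=start)}')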