#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import math
import os
import random
import shutil
import sys
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms

# Escalate all warnings to errors so silent numerical issues surface early.
warnings.filterwarnings("error")

def timeSince(since=None, s=None):
    """Format elapsed seconds (or seconds since timestamp `since`) as 'Xh Ym Zs'."""
    if s is None:
        s = int(time.time() - since)
    m = math.floor(s / 60)
    s %= 60
    h = math.floor(m / 60)
    m %= 60
    return '%dh %dm %ds' % (h, m, s)

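# A minimal usage sketch (added for illustration; this helper is not part of
# the original module): timing a block of work.
def _demo_time_since():
    start = time.time()
    time.sleep(1.5)          # stand-in for real work
    print(timeSince(start))  # e.g. '0h 0m 1s'
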
class AvgrageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

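# Illustrative sketch (not part of the original module): tracking a running
# loss average weighted by batch size.
def _demo_avg_meter():
    meter = AvgrageMeter()
    for loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        meter.update(loss, n=batch_size)
    print(meter.avg)  # 0.74, the weighted mean over the 80 samples seen
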
def get_correct_num(y, target):
    """Count predictions in `y` (logits) that match the `target` labels."""
    pred_label = torch.argmax(y, dim=1)
    return (target == pred_label).sum().item()

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in `topk`."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # reshape() rather than view(): the sliced tensor is non-contiguous,
        # so view() raises an error on recent PyTorch versions.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

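# Illustrative sketch (not part of the original module): top-1/top-5 accuracy
# on random logits for a 10-class problem.
def _demo_accuracy():
    logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    print(top1.item(), top5.item())  # percentages over the batch
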
class Cutout(object):
    """Zero out a randomly positioned square patch of side `length` in a CHW image tensor."""

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)
        mask[y1:y2, x1:x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img

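# Illustrative sketch (not part of the original module): Cutout zeroes a random
# patch of up to 16x16 pixels in a 3x32x32 tensor (smaller at image borders).
def _demo_cutout():
    img = torch.ones(3, 32, 32)
    img = Cutout(length=16)(img)
    print((img == 0).sum().item())  # number of zeroed values, at most 3*16*16
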
def _data_transforms_cifar10(args):
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length))
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform

def _get_cifar10(args):
    train_transform, valid_transform = _data_transforms_cifar10(args)
    train_data = dset.CIFAR10(
        root=args.data, train=True, download=True, transform=train_transform
    )
    valid_data = dset.CIFAR10(
        root=args.data, train=False, download=True, transform=valid_transform
    )
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )
    return train_queue, valid_queue

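# A minimal sketch (illustrative only; the field values are assumptions, the
# field names follow the functions above) of the argparse-style namespace the
# CIFAR-10 loaders expect.
def _demo_cifar10_loaders():
    from types import SimpleNamespace
    args = SimpleNamespace(data='./data', batch_size=96,
                           cutout=True, cutout_length=16)
    train_queue, valid_queue = _get_cifar10(args)
    print(len(train_queue.dataset), len(valid_queue.dataset))  # 50000, 10000
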
def _get_dist_cifar10(args):
    train_transform, valid_transform = _data_transforms_cifar10(args)
    train_data = dset.CIFAR10(
        root=args.data, train=True, download=True, transform=train_transform
    )
    valid_data = dset.CIFAR10(
        root=args.data, train=False, download=True, transform=valid_transform
    )
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_data, num_replicas=args.gpu_num, rank=args.local_rank)
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size // args.gpu_num,
        pin_memory=True,
        num_workers=4,
        drop_last=True,
        sampler=sampler
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )
    return train_queue, valid_queue, sampler

def _get_dist_imagenet(args):
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.gpu_num, rank=args.local_rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size // args.gpu_num,
        num_workers=max(args.gpu_num * 2, 4),
        pin_memory=True, drop_last=True, sampler=sampler)
    val_loader = torch.utils.data.DataLoader(
        dset.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)
    return train_loader, val_loader, sampler

def _data_transforms_cifar100(args):
    CIFAR_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
    CIFAR_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length))
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform

def _get_cifar100(args):
    train_transform, valid_transform = _data_transforms_cifar100(args)
    train_data = dset.CIFAR100(
        root=args.data, train=True, download=True, transform=train_transform
    )
    valid_data = dset.CIFAR100(
        root=args.data, train=False, download=True, transform=valid_transform
    )
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )
    return train_queue, valid_queue

def _get_imagenet_tiny(args):
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(
        mean=[0.4802, 0.4481, 0.3975],
        std=[0.2302, 0.2265, 0.2262]
    )
    train_transform = transforms.Compose([
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length))
    train_data = dset.ImageFolder(
        traindir,
        train_transform
    )
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size // 2, shuffle=False, pin_memory=True, num_workers=4)
    return train_queue, valid_queue

def count_parameters_in_MB(model):
    """Return the total parameter count in millions."""
    return np.sum([np.prod(v.size()) for v in model.parameters()]) / 1e6


def count_parameters(model):
    """Return the total number of elements across all parameter tensors."""
    return sum(torch.numel(v) for v in model.parameters())

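# Illustrative sketch (not part of the original module): parameter counts for
# a small linear layer.
def _demo_param_count():
    layer = nn.Linear(10, 5)              # 10*5 weights + 5 biases = 55 params
    print(count_parameters(layer))        # 55
    print(count_parameters_in_MB(layer))  # 5.5e-05
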
def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    model.load_state_dict(torch.load(model_path))


def load_ckpt(ckpt_path):
    print(f'=> loading checkpoint {ckpt_path}...')
    try:
        checkpoint = torch.load(ckpt_path)
    except Exception:
        print(f'=> failed to load {ckpt_path}')
        sys.exit(1)
    return checkpoint

def save_ckpt(ckpt, file_dir, file_name='model.ckpt', is_best=False):
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)
    ckpt_path = os.path.join(file_dir, file_name)
    torch.save(ckpt, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(file_dir, f'best_{file_name}'))

def drop_path(x, drop_prob, dims=(0,)):
    """Zero `x` with probability `drop_prob` along `dims`, rescaling survivors."""
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        var_size = [1] * x.dim()
        for i in dims:
            var_size[i] = x.size(i)
        # torch.autograd.Variable and torch.cuda.FloatTensor are deprecated;
        # allocate the Bernoulli mask directly on x's device instead.
        mask = torch.empty(var_size, device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x

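# Illustrative sketch (not part of the original module): drop whole samples
# (dim 0) from a batch during training.
def _demo_drop_path():
    x = torch.ones(8, 4)
    x = drop_path(x, drop_prob=0.5, dims=(0,))
    print(x)  # surviving rows are scaled by 1/keep_prob = 2.0
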
def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.makedirs(path)
    print('Experiment dir : {}'.format(path))
    if scripts_to_save is not None:
        os.makedirs(os.path.join(path, 'tools'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'tools', os.path.basename(script))
            shutil.copyfile(script, dst_file)

class Performance(object):
    """Accumulates the softmaxed architecture weights and validation loss per step."""

    def __init__(self, path):
        self.path = path
        self.data = None

    def update(self, alphas_normal, alphas_reduce, val_loss):
        a_normal = F.softmax(alphas_normal, dim=-1)
        a_reduce = F.softmax(alphas_reduce, dim=-1)
        # Convert to CPU/NumPy explicitly so this also works for CUDA tensors.
        data = np.concatenate([a_normal.detach().cpu().numpy().reshape(-1),
                               a_reduce.detach().cpu().numpy().reshape(-1),
                               np.array([val_loss.item()])]).reshape(1, -1)
        if self.data is not None:
            self.data = np.concatenate([self.data, data], axis=0)
        else:
            self.data = data

    def save(self):
        np.save(self.path, self.data)

def logger(log_dir, need_time=True, need_stdout=False):
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)
    fh = logging.FileHandler(log_dir)
    fh.setLevel(logging.DEBUG)
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s',
                                  datefmt='%m/%d/%Y-%I:%M:%S')
    if need_stdout:
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.DEBUG)
        log.addHandler(ch)
    if need_time:
        fh.setFormatter(formatter)
        if need_stdout:
            ch.setFormatter(formatter)
    log.addHandler(fh)
    return log

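# Illustrative sketch (not part of the original module; the file name is an
# assumption): log to a file and mirror messages to stdout.
def _demo_logger():
    log = logger('train.log', need_time=True, need_stdout=True)
    log.info('epoch 0 done')
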
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with label smoothing: the one-hot target is mixed with a
    uniform distribution over classes, weighted by `epsilon`."""

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss

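# Illustrative sketch (not part of the original module): label smoothing as a
# drop-in replacement for nn.CrossEntropyLoss; epsilon=0.1 is a common choice.
def _demo_label_smooth():
    criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
    logits = torch.randn(4, 10)
    targets = torch.randint(0, 10, (4,))
    print(criterion(logits, targets).item())
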
def roc_auc_compute_fn(y_pred, y_target):
    """ROC AUC helper, adapted from ignite.contrib.metrics.roc_auc."""
    try:
        from sklearn.metrics import roc_auc_score
    except ImportError:
        raise RuntimeError("This contrib module requires sklearn to be installed.")
    if y_pred.requires_grad:
        y_pred = y_pred.detach()
    if y_target.is_cuda:
        y_target = y_target.cpu()
    if y_pred.is_cuda:
        y_pred = y_pred.cpu()
    y_true = y_target.numpy()
    y_pred = y_pred.numpy()
    try:
        return roc_auc_score(y_true, y_pred)
    except ValueError:
        # Only one class present in y_true; ROC AUC is undefined in that case.
        return 0.

def load_checkpoint(args):
    try:
        return torch.load(args.resume)
    except RuntimeError:
        raise RuntimeError(f"Failed to load checkpoint at {args.resume}")


def save_checkpoint(ckpt, is_best, file_dir, file_name='model.ckpt'):
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)
    # os.path.join instead of string concatenation, so file_dir does not need
    # a trailing separator.
    ckpt_name = os.path.join(file_dir, file_name)
    torch.save(ckpt, ckpt_name)
    if is_best:
        shutil.copyfile(ckpt_name, os.path.join(file_dir, 'best_' + file_name))

def seed_everything(seed=2022):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducibility.

    [reference] https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

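# Illustrative sketch (not part of the original module): seed before building
# the model and data loaders so runs are repeatable.
def _demo_seeding():
    seed_everything(seed=2022)
    print(torch.rand(1))  # identical across runs with the same seed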