# blob: 3ff20afec6d3f1a5072c1386f1d6ce05009d615e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Utils of using msc examples """
import numpy as np
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
def get_dataloaders(path, train_batch=32, test_batch=1, dataset="cifar10"):
    """Get the data loaders for torch process

    Parameters
    ----------
    path: str
        Root directory passed to the torchvision dataset. The dataset is
        requested with ``download=True``.
    train_batch: int
        Batch size of the training loader (shuffled).
    test_batch: int
        Batch size of the test loader (not shuffled).
    dataset: str
        Name of the dataset to load. Only "cifar10" is handled.

    Returns
    -------
    trainloader: torch.utils.data.DataLoader
        Loader over the training split, with random crop / horizontal flip
        augmentation followed by normalization.
    testloader: torch.utils.data.DataLoader
        Loader over the test split, with normalization only.

    Raises
    ------
    ValueError
        If ``dataset`` is not a supported dataset name.
    """

    # Reject unknown names before touching torchvision, so a typo never
    # triggers a download or any other side effect.
    if dataset != "cifar10":
        raise ValueError("Unexpected dataset " + str(dataset))

    # Per-channel statistics used to normalize the 3-channel images.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2471, 0.2435, 0.2616)

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=path, train=True, download=True, transform=train_transform
    )

    # No augmentation at test time: only tensor conversion and normalization.
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    testset = torchvision.datasets.CIFAR10(
        root=path, train=False, download=True, transform=test_transform
    )

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=train_batch, shuffle=True, num_workers=2
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch, shuffle=False, num_workers=2
    )
    return trainloader, testloader
def eval_model(model, dataloader, max_iter=-1, log_step=100):
    """Evaluate the model

    Parameters
    ----------
    model: torch.nn.Module
        The model to evaluate. It is switched to eval mode and must own at
        least one parameter, since the target device is read from the first.
    dataloader: iterable
        Yields ``(inputs, labels)`` batches. ``len(dataloader)`` is only
        needed when a progress line is printed.
    max_iter: int
        Despite the name, this is compared against the number of evaluated
        samples: evaluation stops once at least ``max_iter`` samples were
        seen. A non-positive value evaluates the whole loader.
    log_step: int
        A progress line is printed whenever the running sample count is a
        multiple of ``log_step``. ``0`` disables logging.

    Returns
    -------
    acc: float
        Fraction of samples whose argmax over dim 1 of the model output
        equals the label. ``0.0`` when no sample was evaluated.
    """

    model.eval()
    device = next(model.parameters()).device
    num_correct, num_datas = 0, 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            outputs = model(inputs.to(device))
            cls_indices = torch.argmax(outputs, dim=1)
            labels = labels.to(device)
            num_datas += len(cls_indices)
            # Accumulate as a tensor so no host sync is forced per batch.
            num_correct += (cls_indices == labels).sum()
            # `log_step and ...` avoids a modulo-by-zero when logging is off.
            if log_step and num_datas > 0 and num_datas % log_step == 0:
                print(
                    "[{}/{}] Torch eval acc: {}".format(
                        i, len(dataloader), num_correct / num_datas
                    )
                )
            if max_iter > 0 and num_datas >= max_iter:
                break
    # Empty loader (or only empty batches): there is no accuracy to compute.
    if num_datas == 0:
        return 0.0
    return (num_correct / num_datas).item()
def train_model(model, dataloader, optimizer, max_iter=-1, log_step=100):
    """Train the model for one pass over the dataloader

    The model is put in train mode and updated in place with a
    cross-entropy loss; nothing is returned. Training stops early once at
    least ``max_iter`` samples were processed (when ``max_iter`` is
    positive). A progress line with the mean loss per step and the running
    accuracy is printed whenever the sample count is a multiple of
    ``log_step``.
    """

    model.train()
    target_device = next(model.parameters()).device
    loss_fn = nn.CrossEntropyLoss()
    log_template = "[{}/{}] Torch train loss: {}, acc {}"
    seen, hits = 0, 0
    loss_sum = 0.0
    for step, (batch, targets) in enumerate(dataloader):
        optimizer.zero_grad()
        targets = targets.to(target_device)
        logits = model(batch.to(target_device))

        # Running accuracy bookkeeping (does not take part in the gradient).
        predictions = torch.argmax(logits, axis=1)
        seen += len(predictions)
        hits += (predictions == targets).sum()

        # Optimization step.
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

        should_log = seen > 0 and seen % log_step == 0
        if should_log:
            print(log_template.format(step, len(dataloader), loss_sum / (step + 1), hits / seen))
        if max_iter > 0 and seen >= max_iter:
            break