#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# type: ignore
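"""Distributed data-parallel MNIST training example.

Trains a small CNN on MNIST with PyTorch DistributedDataParallel (DDP)
using the gloo backend, sharding the dataset across ranks with
DistributedSampler.
"""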
import shutil
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Training hyperparameters.
batch_size = 100
num_epochs = 3
momentum = 0.5
log_interval = 100
class Net(nn.Module):
    """Small convolutional network for MNIST digit classification."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten: 20 channels * 4 * 4 spatial = 320
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Specify dim explicitly; the implicit-dim form is deprecated.
        return F.log_softmax(x, dim=1)
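# A quick shape sanity check (a sketch, not executed during training): each
# 1x28x28 MNIST image shrinks to 20 channels of 4x4 after the two conv/pool
# stages, which is where the 320-unit flatten size comes from.
#
#   model = Net()
#   out = model(torch.zeros(1, 1, 28, 28))
#   assert out.shape == (1, 10)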
def train_one_epoch(model, data_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # len(data_loader) is the number of batches on this rank, so the
            # totals below describe this rank's shard, not the full dataset.
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * batch_size,
                    len(data_loader) * batch_size,
                    100.0 * batch_idx / len(data_loader),
                    loss.item(),
                )
            )
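# Note: no DDP-specific code is needed in the loop above. When the model is
# wrapped in DistributedDataParallel, backward() all-reduces gradients across
# ranks, so each rank's optimizer.step() applies the averaged gradient.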
def train(learning_rate):
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler

    print("Running distributed training")
    # gloo is the CPU backend; rank, world size, and the master address are
    # read from the environment (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT).
    dist.init_process_group("gloo")
    temp_dir = tempfile.mkdtemp()
    train_dataset = datasets.MNIST(
        temp_dir,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )
    # DistributedSampler assigns each rank a disjoint shard of the dataset.
    train_sampler = DistributedSampler(dataset=train_dataset)
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler
    )
    model = Net()
    ddp_model = DDP(model)
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in range(1, num_epochs + 1):
        # Reshuffle each rank's shard; without this, every epoch sees the
        # same ordering.
        train_sampler.set_epoch(epoch)
        train_one_epoch(ddp_model, data_loader, optimizer, epoch)
    dist.destroy_process_group()
    shutil.rmtree(temp_dir)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # nargs="?" makes the positional argument optional so the default applies.
    parser.add_argument("lr", nargs="?", type=float, default=0.001, help="learning rate")
    args = parser.parse_args()
    print("learning rate chosen:", args.lr)
    train(args.lr)
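# One way to launch this script across two local workers, assuming torchrun
# is available and the file is saved as mnist_ddp.py (filename hypothetical):
#
#   torchrun --nproc_per_node=2 mnist_ddp.py 0.01
#
# torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, which
# dist.init_process_group("gloo") reads to form the process group.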