import os
import time
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, sampler
from torch.optim.lr_scheduler import StepLR
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import shutil
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
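# Fix the RNG seeds so weight initialization and dropout masks are reproducible across runs.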
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Return log-probabilities; pair with F.nll_loss in the loss computation.
        x = F.log_softmax(x, dim=1)
        return x
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (images, targets) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)
        pred = model(images)
        # The model outputs log-probabilities, so use the negative log-likelihood loss.
        loss = F.nll_loss(pred, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # print("===> local_rank: {}".format(args.local_rank))
        if idx % args.log_interval == 0 and args.local_rank == 0:
            print("Train Time: {}, epoch: {}, step: {}, loss: {}".format(
                time.strftime("%Y-%m-%d %H:%M:%S"), epoch + 1, idx, loss.item()))
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        for (images, targets) in test_loader:
            images, targets = images.to(device), targets.to(device)
            pred = model(images)
            loss = F.nll_loss(pred, targets, reduction="sum")
            test_loss += loss.item()
            pred_label = torch.argmax(pred, dim=1, keepdim=True)
            test_acc += pred_label.eq(targets.view_as(pred_label)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)
    print("Test Time: {}, loss: {}, acc: {}".format(time.strftime("%Y-%m-%d %H:%M:%S"), test_loss, test_acc))
    return test_acc
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    print("===> save state to {}\n".format(filename))
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
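# main() runs in every process started by the launcher (one process per GPU); each process
# parses the same arguments, joins the process group, and trains on its own device.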
def main():
    parser = argparse.ArgumentParser(description="MNIST TRAINING")
    parser.add_argument('--device_ids', type=str, default='0', help="Training Devices")
    parser.add_argument('--epochs', type=int, default=10, help="Training Epoch")
    parser.add_argument('--log_interval', type=int, default=100, help="Log Interval")
    parser.add_argument('--resume', type=str, default="/home/ubuntu/suyunzheng_ws/mnist/code/checkpoint.pth.tar", help="checkpoint resume path")
    # Note: when using argparse, --local_rank must be declared here; torch.distributed.launch
    # passes it to every process, and the script fails without it.
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    args = parser.parse_args()
    device_ids = list(map(int, args.device_ids.split(',')))
    # torch.distributed.launch sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE in the environment;
    # init_process_group picks them up through the default env:// init method.
    dist.init_process_group(backend='nccl')
    device = torch.device('cuda:{}'.format(device_ids[args.local_rank]))  # differs per process (GPU)
    print("===> device: {}\n".format(device))
    torch.cuda.set_device(device)  # set the current CUDA device for this process
    model = Net().to(device)
    model = DistributedDataParallel(module=model, device_ids=[device_ids[args.local_rank]], output_device=device_ids[args.local_rank], find_unused_parameters=True)  # ignore parameters that exist but are never used in the forward pass
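    # DDP keeps one model replica per process; during backward it all-reduces (averages)
    # gradients across processes, so every replica applies the same update each step.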
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    dataset_train = datasets.MNIST('/home/ubuntu/suyunzheng_ws/mnist/data', train=True, transform=transform, download=True)
    dataset_test = datasets.MNIST('/home/ubuntu/suyunzheng_ws/mnist/data', train=False, transform=transform, download=True)
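    # DistributedSampler hands each rank a disjoint shard of the training set every epoch,
    # so no sample is processed twice per epoch across processes. The test set is evaluated
    # with an ordinary DataLoader on rank 0 only.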
    sampler_train = DistributedSampler(dataset=dataset_train, shuffle=True)
    # In DDP, batch_size is the per-process (per-GPU) batch size; in DataParallel it is the total across all GPUs.
    train_loader = DataLoader(dataset_train, batch_size=8, num_workers=8, sampler=sampler_train)
    test_loader = DataLoader(dataset_test, batch_size=8, shuffle=False, num_workers=8)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = StepLR(optimizer, step_size=1)
    start_epoch = 0
    best_acc = 0
    # Every rank loads the checkpoint (mapped to its own GPU) so that all replicas resume
    # from the same weights; loading on rank 0 alone would leave the other processes with
    # their freshly initialized parameters.
    if args.resume:
        print("===> resume, ..., device: {}".format(device))
        if os.path.isfile(args.resume):
            print("===> loading checkpoint: {}".format(args.resume))
            loc = 'cuda:{}'.format(device_ids[args.local_rank])
            checkpoint = torch.load(args.resume, map_location=loc)
            start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("===> loaded checkpoint {} (epoch {}, best_acc {})".format(args.resume, checkpoint['epoch'], checkpoint['best_acc']))
    for epoch in range(start_epoch, args.epochs):
        # Tell the sampler the epoch number so each epoch gets a different but
        # rank-consistent shuffle of the training data.
        sampler_train.set_epoch(epoch=epoch)
        train(args, model, device, train_loader, optimizer, epoch)
        scheduler.step()
        if args.local_rank == 0:  # local_rank distinguishes the different processes (one per GPU)
            acc = test(args, model, device, test_loader)
            is_best = acc > best_acc
            best_acc = max(acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best=is_best
            )
    # if args.local_rank == 0:
    #     torch.save(model.state_dict(), 'train.pt')
if __name__ == '__main__':
    main()
# python -m torch.distributed.launch --nproc_per_node 2 --master_port 1234 train_ddp.py --device_ids=0,1 --epochs=2
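# Note: torchrun (the newer replacement for torch.distributed.launch) exposes the local rank
# through the LOCAL_RANK environment variable instead of a --local_rank argument; the script
# would need to read os.environ['LOCAL_RANK'] to be launched that way.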