代码:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import argparse
import os
# ImageNet channel statistics: the pretrained torchvision VGG16 weights expect
# inputs normalized with these values.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def parse_args(argv=None):
    """Parse the command-line options for CIFAR-10 fine-tuning.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Pytorch-cifar10_classification')
    parser.add_argument('--epochs', type=int, default=10, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--pretrained_weights', type=str, help='pretrained weights')
    parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    return parser.parse_args(argv)


def build_transforms(img_size):
    """Return ``(train_transform, test_transform)`` producing img_size x img_size tensors.

    Both pipelines normalize with the ImageNet statistics. Only the training
    pipeline is randomized.
    """
    normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size),  # random crop + rescale as augmentation
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation must be deterministic, so resize instead of random-cropping.
    test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def build_model(num_classes, device, pretrained_weights=None):
    """Build an ImageNet-pretrained VGG16 adapted to ``num_classes`` outputs.

    All parameters are frozen except ``classifier[3]`` and the new
    ``classifier[6]`` head.

    Args:
        num_classes: number of output classes of the new final layer.
        device: device the returned model is moved to.
        pretrained_weights: optional path to a state dict for this exact
            architecture (e.g. a checkpoint written by this script).

    Returns:
        The model, already on ``device``.
    """
    model = models.vgg16(pretrained=True)
    # Freeze everything first; selected layers are re-enabled below.
    for param in model.parameters():
        param.requires_grad = False
    # Setting `module.requires_grad = True` only adds a plain attribute to the
    # module object. The flag has to be set on each parameter tensor.
    for param in model.classifier[3].parameters():
        param.requires_grad = True
    # Replace the 1000-way ImageNet head. A freshly created layer is trainable by default.
    model.classifier[6] = nn.Linear(in_features=4096, out_features=num_classes, bias=True)
    if pretrained_weights:
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        model.load_state_dict(torch.load(pretrained_weights, map_location='cpu'))
    return model.to(device)


def train_one_epoch(model, loader, optimizer, device, epoch):
    """Run one optimization pass over ``loader``, printing the loss of every batch."""
    model.train()
    for batch_index, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)
        # forward
        output = model(images)
        loss = F.cross_entropy(output, labels)
        # backward
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # back-propagate
        optimizer.step()       # update the trainable parameters
        print(f'Epoch:{epoch},Batch ID:{batch_index}/{len(loader)}, loss:{loss.item():.4f}')


def evaluate(model, loader, device):
    """Compute the mean cross-entropy loss and accuracy of ``model`` over ``loader``.

    Puts the model in eval mode and disables gradient tracking.

    Returns:
        ``(mean_loss, accuracy)``; ``(0.0, 0.0)`` if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            # Sum per-sample losses so batches of different sizes are weighted correctly.
            total_loss += F.cross_entropy(output, labels, reduction='sum').item()
            correct += (output.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen


def main(argv=None):
    """Fine-tune VGG16 on CIFAR-10, evaluating on the test split after every epoch."""
    args = parse_args(argv)
    print(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Create the output directories recursively; no error if they already exist.
    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    train_transform, test_transform = build_transforms(args.img_size)
    # download=True fetches CIFAR-10 into ./data when it is not already there.
    train_data = datasets.CIFAR10(root='data', train=True, download=True,
                                  transform=train_transform)
    test_data = datasets.CIFAR10(root='data', train=False, download=True,
                                 transform=test_transform)
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size)

    model = build_model(args.num_classes, device, args.pretrained_weights)
    print(model)  # show the network structure

    # Hand the optimizer only the parameters that are actually trained.
    # (Any other torch.optim optimizer, e.g. Adam, can be swapped in here.)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train_one_epoch(model, train_loader, optimizer, device, epoch)
        test_loss, test_acc = evaluate(model, test_loader, device)
        print(f'Epoch:{epoch}, test loss:{test_loss:.4f}, test accuracy:{test_acc:.4f}')
        # A non-positive interval disables checkpointing instead of dividing by zero.
        if args.checkpoint_interval > 0 and epoch % args.checkpoint_interval == 0:
            torch.save(model.state_dict(), f'checkpoints/cifar10_{epoch}.pth')


if __name__ == '__main__':
    main()
说明:
cifar10数据集可以通过torchvision中的datasets.CIFAR10自动下载,也可以自己手动下载(注意存放路径,代码中为 data 目录);我的模型使用的是torchvision.models中预训练好的vgg16网络(卷积层参数被冻结,只训练了部分全连接层),也可以自己搭建网络。