自定义模型的训练、验证与保存
完整的自定义模型,以CIFAR10为例
# encoding:utf-8
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Load the CIFAR-10 train/val splits as tensors and wrap them in batched loaders.
_to_tensor = T.ToTensor()
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=_to_tensor, download=True)
val_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=_to_tensor, download=True)

# Mini-batches of 64 images; default (sequential) sampling, as in the original.
BATCH_SIZE = 64
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
# model definition
class CIFAR10(nn.Module):
    """Small CNN classifier for 32x32 RGB CIFAR-10 images (10 logits out).

    Three conv+maxpool stages (3->32->32->64 channels, each pool halving the
    spatial size 32->16->8->4), then flatten (64*4*4) and two linear layers.
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        # Unpacking keeps the same numeric child names as the original
        # nn.Sequential(...), so state_dict keys are unchanged.
        self.model_1 = nn.Sequential(*stages)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 10) class logits."""
        logits = self.model_1(x)
        return logits
# Instantiate the network together with its loss, optimizer, and log writer.
model = CIFAR10()
loss_entropy = nn.CrossEntropyLoss()  # standard loss for multi-class logits
optimizer = optim.SGD(params=model.parameters(), lr=0.01)  # 0.01 == 1e-2
total_loss = SummaryWriter(log_dir="./total_loss_log")  # TensorBoard event dir
# Training: 10 epochs over the training set; the summed per-batch loss of
# each epoch is logged to TensorBoard under "total_train_loss_per_epoch".
for epoch in range(10):
    epoch_loss = 0.0
    for imgs, targets in train_loader:
        preds = model(imgs)
        batch_loss = loss_entropy(preds, targets)
        # Standard step: clear stale grads, backprop, apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    total_loss.add_scalar("total_train_loss_per_epoch", epoch_loss, epoch)
# Validation: evaluate the trained model on the held-out split.
# Bug fix: the original wrote `with torch.no_grad:` — using the class object
# itself as a context manager, which raises at runtime. It must be *called*:
# `with torch.no_grad():` disables gradient tracking for the block.
model.eval()  # correct practice before inference (matters for Dropout/BatchNorm layers)
# NOTE(review): this loop re-evaluates the same, unchanging model 10 times,
# so every logged point is identical; kept for log-format compatibility, but
# interleaving validation inside the training loop is the usual intent.
for i in range(10):
    total_val_loss = 0.0
    with torch.no_grad():
        for data in val_loader:
            imgs, targets = data
            output = model(imgs)
            loss = loss_entropy(output, targets)
            total_val_loss += loss.item()
    total_loss.add_scalar("total_val_loss_per_epoch", total_val_loss, i)
total_loss.close()  # flush pending TensorBoard events
# save model
# NOTE(review): this pickles the entire module object, which ties the
# checkpoint to this exact class definition and module path; saving
# model.state_dict() instead is the more portable PyTorch convention —
# confirm downstream loaders before changing the format.
torch.save(model, "cifar10.pth")
model.train() 和 model.eval()
通常情况下,不调用 model.train() 和 model.eval() 并不会改变模型的输出结果。
但当模型中包含 Dropout、BatchNorm 等在训练和推理阶段行为不同的层时,必须在训练前调用 model.train()、在评估前调用 model.eval(),否则会影响模型的性能。