Pytorch学习记录(七)自定义模型的训练、验证与保存

自定义模型的训练、验证与保存

完整的自定义模型,以CIFAR10为例

# encoding:utf-8
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Datasets: CIFAR-10 train/val splits; ToTensor converts HWC uint8 images
# to CHW float tensors in [0, 1]. Downloads to ./dataset on first run.
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=T.ToTensor(), download=True)
val_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=T.ToTensor(), download=True)

# DataLoaders. shuffle=True on the training loader so each epoch visits the
# samples in a fresh random order (SGD on identically ordered batches every
# epoch biases the updates); validation order is irrelevant, so it stays
# sequential for reproducible loss logging.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)

# model definition
class CIFAR10(nn.Module):
    """Small CNN for 32x32 RGB CIFAR-10 images.

    Three conv(5x5, same padding) + maxpool(2x2) stages shrink the spatial
    size 32 -> 16 -> 8 -> 4, then two linear layers map the flattened
    64*4*4 features to 10 class logits.
    """

    def __init__(self):
        super(CIFAR10, self).__init__()
        # Keep everything in one Sequential named `model_1` so the
        # state_dict key layout stays stable for saved checkpoints.
        stages = [
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        self.model_1 = nn.Sequential(*stages)

    def forward(self, x):
        """Map a batch of shape (N, 3, 32, 32) to (N, 10) logits."""
        return self.model_1(x)

# Instantiate the network, its loss criterion, the optimizer, and the
# TensorBoard writer used for per-epoch loss curves.
model = CIFAR10()

# Cross-entropy over the 10 class logits; plain SGD with a fixed step size.
learning_rate = 0.01
loss_entropy = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Event files land in ./total_loss_log; view with `tensorboard --logdir`.
total_loss = SummaryWriter("./total_loss_log")

# Training: 10 epochs over the training set, logging the summed batch losses
# of each epoch to TensorBoard.
for i in range(10):
    # Explicit train mode. For this net (no Dropout/BatchNorm) it changes
    # nothing, but it keeps the script correct if such layers are added.
    model.train()
    total_train_loss = 0
    for data in train_loader:
        imgs, targets = data
        output = model(imgs)
        loss = loss_entropy(output, targets)
        # Standard step: clear stale gradients, backprop, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so the accumulator holds a plain float.
        total_train_loss += loss.item()
    total_loss.add_scalar("total_train_loss_per_epoch", total_train_loss, i)

# Validation: compute the summed loss over the validation set with gradient
# tracking disabled (saves memory, no autograd graph is needed).
# NOTE(review): the model is fully trained by this point, so all 10
# iterations compute the same value; kept to preserve the original
# per-"epoch" logging layout. Interleaving validation into the training
# loop would be the conventional structure.
for i in range(10):
    total_val_loss = 0
    # Explicit eval mode — matters once Dropout/BatchNorm layers exist.
    model.eval()
    # BUG FIX: `with torch.no_grad:` passed the class object itself and
    # raises a TypeError at runtime; the context manager must be
    # instantiated as `torch.no_grad()`.
    with torch.no_grad():
        for data in val_loader:
            imgs, targets = data
            output = model(imgs)
            loss = loss_entropy(output, targets)
            total_val_loss += loss.item()
    total_loss.add_scalar("total_val_loss_per_epoch", total_val_loss, i)

# Flush and close the TensorBoard writer so all pending events hit disk.
total_loss.close()

# Save the trained model.
# NOTE(review): this pickles the entire module object, so loading requires
# the CIFAR10 class to be importable under the same path; saving
# model.state_dict() instead is the recommended, more portable format —
# consider switching if the loading code can be updated too.
torch.save(model, "cifar10.pth")

model.train() 和model.eval()

若模型中不包含与训练/推理模式相关的层,是否调用 model.train() 和 model.eval() 不会影响模型的输出结果。
但当模型中含有 Dropout、BatchNorm 等在训练与推理阶段行为不同的层时(Dropout 在推理时需关闭,BatchNorm 在推理时使用累计的运行统计量),就必须在训练前调用 model.train()、在验证/推理前调用 model.eval(),否则会影响模型的表现。

上一篇:杰理之获取录音播放总事件【篇】


下一篇:【数据结构与算法】算法的时间复杂度