使用pytorch进行机器学习
- 1. Config and Seed 配置与种子数设置
- 2. Transform/Pre-processing 预处理
- 3. Dataset 数据集创建
- 4. Dataloader 数据加载
- 5. Model 网络模型
- 6. Optimizer 优化器
- 7. Loss 损失函数
- 8. Train 训练
- 9. Val and Save 验证与保存模型
- 10. Test 测试模型
- 11. Other 其它
个人学习笔记,使用pytorch搭建学习框架中用到的基本结构与一些代码的写法,包含详细说明与部分样例,部分内容未完成,将随着学习进度不定时补充,喜欢请点赞~
部分参考代码来自网络教程与kaggle的notebook,文中的代码并非直接就能运行的某个项目的完整代码,大部分是写法的样例,需要自己进行具体细节修改。
1. Config and Seed 配置与种子数设置
创建config类,将大部分常用的配置项集中在一起,方便在其它代码中使用[config.属性名]调用。
此处举例一些自己使用过的配置内容,具体配置项根据项目具体情况调整。
class config:
# 设置种子数
seed = 26
# 是否使用种子数
use_seed = True
# 用于图像处理,配置图像输入网络的长宽
img_size = 512
# 使用哪一种网络模型,此处配置仅为字符串,由其它代码实现调用
model_name = "efficientnet"
# model_name = "resnet50"
# model_name = "resnext50_32x4d"
# 是否使用旧模型继续训练
# 如果为是,直接加载保存的旧模型参数并继续训练
# 如果为否,加载随机模型参数或者预训练模型参数
from_old_model = True
# 是否只训练输出层,用于模型迁移时,先仅对预训练模型的输出层训练
only_train_output_layer = False
# 用于图片处理,是否使用图像增强的方式扩充训练集
use_image_enhancement = True
# 设置优化器的学习率
learning_rate = 1e-4
# 使用哪一种优化器,此处配置仅为字符串,由其它代码进行具体实现
# optimizer_name = "SGD"
# optimizer_name = "Adam"
optimizer_name = "AdamW"
# 代码最大运行多少个epoch后中止
epochs = 50
# batch size的大小
batchSize = 8
# 设置一个最小初始保存正确率,只在磁盘上保存正确率最高的模型
# 用于以accuracy为标准的训练,并非适用于所有类型的数据集
# 根据数据集具体情况调整使用哪种参数作为保存标准
lowest_save_acc = 0
# 损失函数,此处例子为torch自带的多标签损失函数
criterion = nn.MultiLabelSoftMarginLoss()
# 如果要从数据集中分割一定比例作为验证集,可以在这里配置大小
size_of_val_dataset = 0.2
# 用于交叉验证的情况下
# 将数据集以20%的比例分割为五份,编号为0-4
# 此时取出一个部分作为验证集,将另外四份作为训练集,训练5个模型
# 为了方便,可以直接在这片配置将哪个索引的部分作为验证集
# 具体实现由其它代码完成
val_index = 0
# 配置模型输出频道,比如在分类任务中要输出多少种不同的分类
output_channel = 5
# 将读取的数据储存在哪里,写了两个适用于不同情况的dataset便于切换
# 总量较小的数据可以完全储存在内存中,减少磁盘访问次数,从而提高访问速度和节省磁盘寿命
# 总量较大的数据无法完全装入内存,例如大型图片数据集,只能在每次训练的时候从磁盘上读取,存放在固态可以大幅提高访问速度。
# read_data_from = "Memory"
read_data_from = "Disk"
# 路径配置,此处纯举例,例如从csv,txt,文件夹读取训练数据与模型数据
train_csv_path = "train.csv"
noise_path = "noise.txt"
train_image = "train_images/"
log_name = "log.txt"
model_path = "trained_models/save_model_" + model_name + "_" + str(val_index) + ".pth"
# 配置一个记录最佳正确率的模型,记录该模型在第几个epoch出现,正确率为多少
# 将epoch编号初始化为-1,保存最低正确率从配置文件中读取
best_val_acc = (-1, config.lowest_save_acc)
# 固定随机数种子,从而实现复现
def seed_torch(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# 仅当配置启用时,固定随机数种子
if config.use_seed:
seed_torch(seed=config.seed)
2. Transform/Pre-processing 预处理
2.1 图片的预处理
使用albumentations进行图像增强比使用torch自带的库要强大得多,此处示例仅展示了训练和测试中分别用到的两个图像处理方式的大致写法,具体内容根据具体情况调整。
包括了随机的图片裁剪缩放、其它图像增强方式和转换为tensor格式。
from albumentations import (Blur,Flip,ShiftScaleRotate,GridDistortion,ElasticTransform,HorizontalFlip,CenterCrop,RandomResizedCrop,
HueSaturationValue,Transpose,RandomBrightnessContrast,CLAHE,RandomCrop,Cutout,CoarseDropout,
CoarseDropout,Normalize,ToFloat,OneOf,Compose,Resize,RandomRain,RandomFog,Lambda
,ChannelDropout,ISONoise,VerticalFlip,RandomGamma,RandomRotate90,RandomSizedCrop,ToGray,BboxParams,MotionBlur,MedianBlur)
from albumentations.pytorch import ToTensorV2
import cv2
# 获取训练的转换方式
def get_train_transforms(img_size):
return Compose(
[RandomResizedCrop(img_size, img_size),
#RandomCrop(224, 224),
OneOf([
RandomGamma(gamma_limit=(60, 120), p=0.9),
RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),
CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=0.9),
]),
OneOf([
Blur(blur_limit=3, p=1),
MotionBlur(blur_limit=3, p=1),
MedianBlur(blur_limit=3, p=1)
], p=0.5),
HorizontalFlip(p=0.5),
VerticalFlip(p=0.5),
HueSaturationValue(hue_shift_limit=0.2,sat_shift_limit=0.2,val_shift_limit=0.2,p=0.5),
ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20,
interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, p=1),
Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
CoarseDropout(p=0.5),
ToTensorV2(p=1.0),]
)
# 获取测试的转换方式
def get_test_transforms(img_size):
return Compose(
[Resize(img_size, img_size),
Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
ToTensorV2(p=1.0),]
)
2.2 文本的预处理(未完成)
3. Dataset 数据集创建
Dataset类决定了程序将如何从磁盘上读取现有的数据,包括txt,csv或者图片之类的数据。
一般如果只是单纯的文本信息,可以全部存放在内存中,但涉及大量的图片这样的数据集时,从磁盘中依次读取可以防止电脑的内存溢出。
当然可以写两种dataset,分别用磁盘和内存储存,然后用一个if通过config内的配置项切换来方便调试。
Dataset主要有三个必写的基本部分:
from torch.utils.data import Dataset
class Leaf_train_Dataset(Dataset):
def __init__(self, data_csv, img_path, transform):
# 定义数据集自身的属性,可以导入csv文件,路径,转换方式
# 如果将数据存放在这里,那么所有数据会被保存在内存中
self.data = ...
self.transform = transform
def __getitem__(self, index):
# 定义如何从__init__定义的属性中按照index取出指定的数据
# 在训练时返回数据和label,在测试时只返回数据
# 如果只把csv文件保存在内存中,那么可以每次取出路径,然后在这里读取磁盘上的文件,将读取到的图片等信息返回
data, label = ...
return data, label
def __len__(self):
# 返回整个数据集的长度,自己定义如何获取这个长度
length = ...
return length
init中的数据保存在内存中,因此要把整个数据集都保存在内存中,就在init方法内完成所有读取工作。如果要逐批从磁盘读取数据,就在init中保存索引,然后在getitem中完成数据的磁盘读取部分。
使用Dataset读取数据,获得一个dataset的对象。
train_dataset = Leaf_train_Dataset(train_csv, config.train_image, transform=get_train_transforms(config.img_size))
如果要分割数据集等操作,可以在csv阶段就完成,也可以使用随机切分的函数直接切分Dataset,但是为了保证每次分割的结果一样,最好在csv阶段就完成数据分割。
4. Dataloader 数据加载
完成了创建Dataset后,将其转换为Dataloader用于训练输入。
batch_size为每次从dataset中取出的数据量,shuffle表示是否要打乱原本的顺序。
train_loader = DataLoader(dataset=train_dataset, batch_size=config.batchSize, shuffle=True)
5. Model 网络模型
训练的模型是torch中核心的部分,它决定了网络的架构,可以自己写一个自定义的模型,也可以导入现成的模型并且迁移预训练数据,在已有模型的基础上进行进一步训练。
5.1 自定义模型
5.1.1 自定义模型的基本结构
import torch.nn as nn
class my_network(nn.Module):
def __init__(self):
super(my_network, self).__init__()
# 在此定义神经网络中遇到的各个层
def forward(self, input):
# input为网络的输入数据,output为网络的输出数据
# 在此定义输入到输出过程中的一系列处理步骤
# 包括经过的层的顺序,激活函数,数据连接方式等
return output
5.1.2 只有一个线性层的网络
class NetLin(nn.Module):
# linear function followed by log_softmax
def __init__(self):
super(NetLin, self).__init__()
# 定义线性层输入为784个节点,输出为10个节点
self.linear = nn.Linear(28*28,10)
def forward(self, x):
# 将28x28的二阶矩阵输入展开为1x784的一阶矩阵
x = x.view(x.shape[0],-1)
# 将数据输入线性层,得到尺寸为1x10的输出数据
x = self.linear(x)
return x
5.1.3 有一个隐藏层的网络
class NetFull(nn.Module):
def __init__(self):
super(NetFull, self).__init__()
# 假设隐藏层有50个节点
hid_nodes_num = 50
# 第一个线性层输入为784个节点,输出50个节点
self.in_to_hid = nn.Linear(28*28,hid_nodes_num)
# 第二个线性层输入50个节点,输出10个节点
self.hid_to_out = nn.Linear(hid_nodes_num,10)
def forward(self, x):
# 将28x28的二阶矩阵输入展开为1x784的一阶矩阵
x = x.view(x.shape[0],-1)
# 用第一个线性层将数据从784个变为50个
hid_sum = self.in_to_hid(x)
# 在一个线性层之后使用tanh激活函数使得神经网络可以处理非线性问题
hidden = torch.tanh(hid_sum)
# 用第一个线性层将数据从50个变为10个
output = self.hid_to_out(hidden)
return output
5.1.4 有卷积层的网络
卷积层的具体定义和功能在此不再赘述,仅列出代码的具体写法。
class NetConv(nn.Module):
# two convolutional layers and one fully connected layer,
# all using relu, followed by log_softmax
def __init__(self):
super(NetConv, self).__init__()
# 定义第一个卷积层
self.conv1 = nn.Conv2d(1,32,kernel_size=3)
# 定义第二个卷积层
self.conv2 = nn.Conv2d(32,64,kernel_size=3)
# 定义第一个线性层
self.linear1 = nn.Linear(36864,200)
# 定义第二个线性层
self.linear2 = nn.Linear(200,10)
# 定义激活函数
self.ReLU = nn.ReLU()
def forward(self, x):
# 卷积层的输入不需要展开为一维,每层后接一次激活函数
x = self.conv1(x)
x = self.ReLU(x)
x = self.conv2(x)
x = self.ReLU(x)
# 将卷积层的二维输入放入线性层之前要将其展开为一维矩阵
x = x.view(x.shape[0],-1)
x = self.linear1(x)
x = self.ReLU(x)
output = self.linear2(x)
return output
5.1.5 有跨层连接的网络
class ShortNet(torch.nn.Module):
def __init__(self, num_hid):
super(ShortNet, self).__init__()
# 定义线性层
self.in_to_hid1 = nn.Linear(2,num_hid)
self.in_to_hid2 = nn.Linear(2,num_hid)
self.hid1_to_hid2 = nn.Linear(num_hid,num_hid)
self.in_to_out = nn.Linear(2,1)
self.hid1_to_out = nn.Linear(num_hid,1)
self.hid2_to_out = nn.Linear(num_hid,1)
def forward(self, input):
hid_sum1 = self.in_to_hid1(input)
self.hidden1 = torch.tanh(hid_sum1)
# 将之前网络的输出与最初的输入数据分别处理后相加,从而实现跨层连接
hid_sum2 = self.in_to_hid2(input) + self.hid1_to_hid2(self.hidden1)
self.hidden2 = torch.tanh(hid_sum2)
# 将之前所有层的数据分别处理后都相加
out_sum = self.in_to_out(input) + self.hid1_to_out(self.hidden1) + self.hid2_to_out(self.hidden2)
# 使用sigmoid激活函数
output = torch.sigmoid(out_sum)
return output
5.1.6 封装处理序列
以5.1.4中的网络为例,将多个层的定义封装起来方便调用。
class NetConv(nn.Module):
# two convolutional layers and one fully connected layer,
# all using relu, followed by log_softmax
def __init__(self):
super(NetConv, self).__init__()
self.features = tnn.Sequential(
nn.Conv2d(1,32,kernel_size=3),
nn.ReLU(True),
nn.Conv2d(32,64,kernel_size=3),
nn.ReLU(True),
)
self.classifier = nn.Sequential(
nn.Linear(36864, 200),
nn.ReLU(True),
nn.Linear(200, 10),
)
def forward(self, x):
# 使用卷积层提取特征
features = self.features(img)
# 展平二维矩阵
features = features.view(features.shape[0],-1)
# 使用线性层分类
output = self.classifier(features)
return output
5.1.7 将输入的横纵坐标转换为极坐标
class PolarNet(torch.nn.Module):
def __init__(self, num_hid):
super(PolarNet, self).__init__()
self.in_to_hid = nn.Linear(2,num_hid)
self.hid_to_out = nn.Linear(num_hid,1)
def forward(self, input):
# 将x,y的输入转换为r和a的极坐标系
r = torch.norm(input,2,dim=-1)
a = torch.atan2(input[:,1],input[:,0])
# 使用torch.stack在第一个维度上连接数据
input_trans = torch.stack((r,a),1)
hid_sum = self.in_to_hid(input_trans)
self.hidden1 = torch.tanh(hid_sum)
out_sum = self.hid_to_out(self.hidden1)
output = torch.sigmoid(out_sum)
return output
torch.cat() 和 torch.stack()都可以用于连接数据。
torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变。
torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维。
5.2 调用现有模型
5.2.1 修改输出频道数
大部分分类模型默认的输出为1000个分类,因此在使用时需要对模型的输出线性层进行微调。
注:并非所有模型的输出层名字都一样,例如efficientnet的输出层叫_fc,而resnet的输出层叫fc。
修改方式为使用一个新的线性层代替原有模型:
net = EfficientNet.from_name('efficientnet-b0')
net._fc = nn.Linear(net._fc.in_features, output_channel)
错误修改方式:直接修改输出频道数,这种方式虽然不会报错,但实际上模型并不会有任何改变,真实的输出仍然是1000,不仔细看很容易上当受骗。
错误样例:
# 这是错误样例!!!
net = EfficientNet.from_name('efficientnet-b0')
net._fc.out_features = output_channel
5.2.2 仅训练输出层
在很多现成可以调用的模型中,它们有已经训练好的预训练数据,完全可以在这些数据的基础上进行微调,但因为输出层的不一样,输出层的预训练数据可能是完全错误的。
在这种情况下,可以先不更新其它层的数据,仅仅在更新输出层数据的情况下先对输出层的参数进行训练。
假设我们已经在config里面设置了only_train_output_layer的参数,值为True或False。
if config.only_train_output_layer:
for name, value in net.named_parameters():
# 如果参数不是输出层的,就不更新
if name != "_fc.weight" and name != "_fc.bias" and name != "fc.weight" and name != "fc.bias":
value.requires_grad = False
# 移除所有不需要更新的参数
params = filter(lambda x: x.requires_grad, net.parameters())
# 如果没有启用配置,就对所有参数训练
else:
params = net.parameters()
# 将参数放入优化器
optimizer = toptim.SGD(params, lr=config.learning_rate)
5.2.3 经典的网络调用方式
EfficientNet的库需要自己下载安装。
from efficientnet_pytorch import EfficientNet
# 名字可以是b0-b7
# 这样是只加载网络
net = EfficientNet.from_name('efficientnet-b0')
# 这样是加载网络的同时加载预训练数据
net = EfficientNet.from_pretrained('efficientnet-b0')
ResNet50
from torchvision import models
# 配置是否要加载预训练数据
net = models.resnet50(pretrained=False)
Resnext50_32x4d
注:整个模型的预训练数据似乎总是下载不动……可以配置为False防止卡住。
import timm
# 配置是否要加载预训练数据
self.model = timm.create_model('resnext50_32x4d', pretrained=False)
5.2.4 加载之前保存的模型
model_path = "save_model_efficientnet.pth"
net = EfficientNet.from_name('efficientnet-b0')
# 加载保存的模型参数
net.load_state_dict(torch.load(model_path))
如何保存模型参数到磁盘详见第9部分。
6. Optimizer 优化器
优化器是网络学习过程中用于更新网络参数的工具。
假设已经在config中配置了learning_rate作为初始学习率。
6.1 SGD
最稳定,最保守,最终效果大概率最好的优化器,就是有亿点点慢……
optimizer = toptim.SGD(params, lr=config.learning_rate)
6.2 Adam
高速训练优化器,但是最终的模型效果不能保证。
optimizer = toptim.Adam(params, lr=config.learning_rate)
6.3 AdamW
在Adam基础上增加了L2正则化效果的优化器,理论上比Adam的最终效果应该更好……恩,理论上。
需要额外配置一个参数。
from torch.optim import AdamW
optimizer = AdamW(params, lr=config.learning_rate, weight_decay=1e-6)
7. Loss 损失函数
损失函数有很多种,一般来说有连续型和离散型两类。
平方损失函数(Square Loss)和绝对值损失函数(Absolute Value Loss)适用于连续值的学习。
# 平均绝对值误差损失
criterion = torch.nn.L1Loss()
# 均方误差损失
criterion = torch.nn.MSELoss()
交叉熵损失(Cross-Entropy Loss)适用于分类问题的离散值学习。
离散值中又有二分类、多分类、多标签(一个数据可能同时有多个标签)的区别。
# 二分类交叉熵,输入为一组数据和一个标签
criterion = nn.BCEWithLogitsLoss()
# 多分类交叉熵,输入为一组数据和一个标签
criterion = nn.CrossEntropyLoss()
# 多标签交叉熵,输入为一组数据和一组正确的标签
criterion = nn.MultiLabelSoftMarginLoss()
此处仅仅举出少数例子,更多损失函数请查询相关文档。
还有一些无法直接调用,需要从网上复制粘贴 自己编写的损失函数计算方式,例如:
用于应对错误样本标注的标签平滑LabelSmoothingLoss()
用于应对样本数量极度不平衡的FocalLoss()
8. Train 训练
为了简洁方便,将训练过程写成一个单独的函数。
在main部分中写:
# start train
train(net, train_loader, config.criterion, optimizer, epoch, device, log)
然后调用一个专门的train()函数:
def train(net, train_loader, criterion, optimizer, epoch, device, log):
# 记录训练中的loss
runningLoss = 0
loss_count = 0
# 为了便于显示进度
batch_num = len(train_loader)
# 将dataloader中的数据依次取出
for index, (imgs, labels) in enumerate(train_loader):
# 如果使用GPU加速训练,将数据放入GPU的内存
imgs, labels = imgs.to(device), labels.to(device)
# 将grad重置为0
optimizer.zero_grad()
# 网络的forward操作
# 将输入数据放入模型,得到输出数据
output = net(imgs)
# 计算loss
# 此处需要注意数据格式符合损失函数的要求,这里用的是多标签交叉熵
loss = criterion(output, labels.long())
# 累计loss
runningLoss += loss.item()
loss_count += 1
# 反向传递
loss.backward()
# 更新参数
optimizer.step()
# 每隔400个批次就打印一次当前进度与损失,方便查看
if (index + 1) % 400 == 0:
print("Epoch: %2d, Batch: %4d / %4d, Loss: %.3f" % (epoch + 1, index + 1, batch_num, loss.item()))
# 计算一个epoch中的平均loss
avg_loss = runningLoss / loss_count
print("For Epoch: %2d, Average Loss: %.3f" % (epoch + 1, avg_loss))
log.write("For Epoch: %2d, Average Loss: %.3f" % (epoch + 1, avg_loss) + "\n")
9. Val and Save 验证与保存模型
每个epoch跑完后,都进行一次验证,假设以accuracy为标准,只保留所有epoch中accuracy最高的那个模型。
为了简洁方便,将验证过程写成一个单独的函数。
在main部分中写:
# start val
val(net, val_loader, config.criterion, optimizer, epoch, device, log, train_start)
然后调用一个专门的val()函数:
def val(net, val_loader, criterion, optimizer, epoch, device, log, train_start):
# 告诉模型net进入了验证的过程
net.eval()
# 在验证过程中关闭torch的自动构建计算图
with torch.no_grad():
# 记录长度
total_len = 0
correct_len = 0
# 全局中目前最高accuracy的模型记录是多少
global best_val_acc
# 取出验证集中的数据依次输入
for index, (imgs, labels) in enumerate(val_loader):
# 将数据复制到GPU内存
imgs, labels = imgs.to(device), labels.to(device)
# 网络模型的输入输出
output = net(imgs)
# 分类任务中使用argmax取出概率最大的标签作为最终结论,然后展平矩阵
pred = output.argmax(dim=1, keepdim=True).flatten()
# 让label的标签格式保持和pred一致
labels = labels.argmax(dim=1, keepdim=True).flatten()
# 统计预测正确的数量
assessment = torch.eq(pred, labels)
total_len += len(pred)
correct_len += int(assessment.sum())
# 计算accuracy
accuracy = correct_len / total_len
print("Start val:")
print("accuracy:", accuracy)
log.write("accuracy: " + str(accuracy) + "\n")
# 如果accuracy超出了历史最高纪录,保存模型
if accuracy > best_val_acc[1]:
# 更新历史最高记录
best_val_acc = (epoch+1, accuracy)
# 保存模型
torch.save(net.state_dict(), config.model_path)
print("Model saved in epoch "+str(epoch+1)+", acc: "+str(accuracy)+".")
log.write("Model saved in epoch "+str(epoch+1)+", acc: "+str(accuracy)+".\n")
# 打印到目前为止所用的运行时间,便于参考
current_time = time.time()
pass_time = int(current_time - train_start)
time_string = str(pass_time // 3600) + " hours, " + str((pass_time % 3600) // 60) + " minutes, " + str(
pass_time % 60) + " seconds."
print("Time pass:", time_string)
print()
log.write("Time pass: " + time_string + "\n\n")
10. Test 测试模型
测试部分和验证部分很相似,只不过少了对比部分,只要将输出直接保存到磁盘即可。
注:测试代码可以写在别的文件中,直接从磁盘加载训练好的模型参数文件载入网络,而测试的Dataset不需要返回label。
给出项目中的局部代码作为参考。
result = []
with torch.no_grad():
batch_num = len(test_loader)
for index, image in enumerate(test_loader):
image = image.to(device)
output = net(image)
pred = output.argmax(dim=1, keepdim=True)
pred = pred.view(pred.shape[0], -1)
result = result + list(map(lambda x:int(x), pred))
if (index + 1) % 10 == 0:
print("Batch: %4d / %4d" % (index + 1, batch_num))
# 将最终结果保存到CSV文件
pred_result = pd.concat([pd.DataFrame(file_list, columns=['image_id']), pd.DataFrame(result, columns=['label'])], axis=1)
pred_result.to_csv(output_path + "submission.csv", index=False, sep=',')
11. Other 其它
11.1 GPU加速
前置要求:安装好cuda与GPU版本的pytorch。
测试是否可用:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Use " + str(device))
在训练前,将tensor数据复制到训练所用设备的内存中:
data = data.to(device)
注:部分只能在cpu上进行的运算,如numpy库中的部分运算,需要先将数据复制回cpu的内存中才能进行。
大部分使用过程在之前的训练和验证代码中其实已经给出。
11.2 用Apex减少显存占用与训练时间
Apex是一种将训练时间和显存占用几乎减半的方式,只需要修改代码中的三个地方就可以简单实现。
假设已经在conifg中配置了use_apex的属性用于迅速是否使用apex。
在main()部分中添加:
from apex import amp
# 混合精度加速
if config.use_apex:
net, optimizer = amp.initialize(net, optimizer, opt_level="O1")
在train()部分中对反向传播进行修改:
if config.use_apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
如有收获,欢迎点赞,如有错误,欢迎大佬们指出问题帮助修正。