Pytorch常用工具箱

神经网络工具箱nn

import torch.nn as nn

在nn中主要有两个重要模块:nn.Model、nn.functional,接着将分别介绍这两个模块。

nn.Model

nn.Model是nn的一个核心数据结构,最常用的做法就是继承nn.Model,如 class Nets(nn.Model),常用的全连接层、损失层、激活层、卷积层等都是nn.Model的子类,nn.Linear、nn.Conv2d等

nn.functional

import torch.nn.functional as F

性能方面与nn.Model有一些差异,但此处不做描述。调用常用的层时用nn.functional.xxx,如nn.functional.linear、nn.functional.conv2d。

utils.data

utils.data 主要包括4个类
(1)Dataset:是一个抽象类,其他数据需要继承这个类,并要覆写其中的两个方法(getitemlen
(2)DataLoader:定义了一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)等

from pytorch.utils.data import DataLoader

(3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
(4)*sampler:多采样函数。

Torchvision

torchvision主要包括4个类
(1)datasets:常用数据集的加载,如MMIST,CIFAR10,设计上继承于torch.utils.data.Dataset
(2)models:提供经典的网络结构以及训练好的模型;
(3)transforms:常用的数据预处理操作,主要包括Tensor及PIL Image对象的操作;
(4)utils:包括两个函数,一个是make_grid,主要是将多张图片拼接在一起,一个是save_img,能将tensor保存成图片。

上一篇:Pytorch中nn.Moudle模块和nn.functional模块库的不同


下一篇:Go:functional Options编程模式