2021-11-02

pytorch中torchvision.transforms的一些理解

1.这个库里面主要是包含了一些图像处理的函数,也就是说使用.transforms的地方同样可以用其他图像库进行处理,例如opencv。
2.这个库一般只用于和torchvision.datasets一起使用的时候,其他的一般自己弄就行了。

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=BATCH_SIZE, shuffle=True)

3.我们使用pytorch的时候用的最多的就是这两句:

   transforms.ToTensor(),#归一化将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor
   transforms.Normalize((0.1307,), (0.3081,))  #标准化是为了加快收敛性 这里的0.1307和0.3081是MNIST数据集里的均值和标准差,因为只有一个通道,所以只写了一个这个东西一般是数据集提供方给出的。

对于其他的操作我们也可以用其他的库进行图像处理。

上一篇:Pytorch以单通道(灰度图)加载图片


下一篇:IMDB 电影评论情感分类数据集