pytorch中torchvision.transforms的一些理解
1.这个库里面主要是包含了一些图像处理的函数,也就是说使用.transforms的地方同样可以用其他图像库进行处理,例如opencv。
2.这个库一般只用于和torchvision.datasets一起使用的时候,其他的一般自己弄就行了。
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
3.我们使用pytorch的时候用的最多的就是这两句:
transforms.ToTensor(),#归一化将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor
transforms.Normalize((0.1307,), (0.3081,)) #标准化是为了加快收敛性 这里的0.1307和0.3081是MNIST数据集里的均值和标准差,因为只有一个通道,所以只写了一个这个东西一般是数据集提供方给出的。
对于其他的操作我们也可以用其他的库进行图像处理。