pytorch笔记:torchvision.datasets.ImageFolder

首先看看官方文档对它的解释

def __init__(self,
      root: str,
      transform: Any = None,
      target_transform: Any = None,
      loader: Any = default_loader,
      is_valid_file: Any = None) -> None
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
 
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
 
Params:
root – Root directory path.
transform – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform – A function/transform that takes in the target and transforms it.
loader – A function to load an image given its path.
is_valid_file – A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files)

root –根目录路径。
transform–接受PIL图像并返回变换后的图像。
target_transform –接收目标并对其进行转换的函数/转换。
loader –加载给定路径的图像。
is_valid_file –获取图像文件路径并检查文件是否为有效文件的函数(用于检查损坏的文件)

由官方文档可以看出,torchvision.datasets.ImageFolder是一个通用数据加载器。并且要加载的所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片。

以猫狗图片为例,目录结构如下:

DogVSCats
|-----train
|  |-----cat
|    |-----cat01.JPG
|    |-----cat02.JPG
|  |-----dog
|    |-----dog01.JPG
|    |-----dog02.JPG

DogVSCats文件夹和.py文件在同级目录,则参数中的root即为DogVSCats\train

例子:

import torch
from torchvision import datasets, transforms
import os

'''
-------------------对数据进行载入-------------------------
'''
data_dir = "DogsVSCats"
data_transform = {x: transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor()])
                  for x in ["train", "valid"]}  # 字典类型,key为train和valid,value是x: 后的内容
image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x])
                  for x in ["train", "valid"]}
dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x], batch_size=16, shuffle=True)
              for x in ["train", "valid"]}

x_example, y_example = next(iter(dataloader["train"]))
print("x_example个数", len(x_example))
print(type(x_example))
print("y_example个数", len(y_example))

输出:
x_example个数 16
<class ‘torch.Tensor’>
y_example个数 16

上一篇:Python 实现用户登录系统 案例一(基于hashlib & sys)


下一篇:【Python系列专栏】第四十篇 Python中常用内建模块(hashlib)