First, let's look at how the official documentation describes it:
def __init__(self,
             root: str,
             transform: Any = None,
             target_transform: Any = None,
             loader: Any = default_loader,
             is_valid_file: Any = None) -> None
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Params:
root – Root directory path.
transform – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
target_transform – A function/transform that takes in the target and transforms it.
loader – A function to load an image given its path.
is_valid_file – A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).
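To make the is_valid_file parameter more concrete, here is a small sketch of a predicate with the expected shape: it takes a file path and returns True or False. The function name is_nonempty_image, the extension list, and the "non-empty file" rule are assumptions chosen for illustration and are not part of torchvision; a stricter check could try to decode the file with PIL instead.

```python
import os
import tempfile

# Hypothetical whitelist; adjust to the formats in your dataset.
ALLOWED_EXTENSIONS = (".jpg", ".jpeg", ".png")


def is_nonempty_image(path: str) -> bool:
    """Accept files with an image extension that are not zero bytes long."""
    has_image_ext = path.lower().endswith(ALLOWED_EXTENSIONS)
    return has_image_ext and os.path.getsize(path) > 0


# Try the predicate on three throwaway files.
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "cat01.JPG")
    empty = os.path.join(tmp, "cat02.JPG")
    notes = os.path.join(tmp, "notes.txt")
    with open(good, "wb") as f:
        f.write(b"\xff\xd8\xff")  # a few bytes so the file is non-empty
    open(empty, "wb").close()     # simulates a truncated download
    with open(notes, "w") as f:
        f.write("not an image")

    results = {os.path.basename(p): is_nonempty_image(p) for p in (good, empty, notes)}

print(results)  # {'cat01.JPG': True, 'cat02.JPG': False, 'notes.txt': False}
```

A function like this would be passed as datasets.ImageFolder(root, is_valid_file=is_nonempty_image); files for which it returns False are left out of the dataset.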
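As the official documentation shows, torchvision.datasets.ImageFolder is a generic image dataset (the documentation calls it a "generic data loader", but it is a Dataset that you would typically wrap in a DataLoader). It expects the files to be organized into folders, with each subfolder holding the images of a single class; the subfolder name is used as the class name.
Taking cat and dog images as an example, the directory structure looks like this: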
DogVSCats
|-----train
|     |-----cat
|     |     |-----cat01.JPG
|     |     |-----cat02.JPG
|     |-----dog
|     |     |-----dog01.JPG
|     |     |-----dog02.JPG
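If the DogVSCats folder is in the same directory as the .py file, the root argument is DogVSCats\train (DogVSCats/train on Linux and macOS). The example below also loads a valid split, so it expects a valid folder with the same cat/dog layout next to train.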
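How the folder names turn into labels can be sketched without torchvision. The snippet below is a minimal illustration, not the library's actual code: it assumes that class indices are assigned to the subfolder names in sorted order, which is what ImageFolder exposes through its class_to_idx attribute. The helper names find_classes_sketch and list_samples_sketch, and the temporary directory with empty placeholder files, are made up for the demonstration.

```python
import os
import tempfile


def find_classes_sketch(root):
    """Map each class subfolder under `root` to an integer index (sorted by name)."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


def list_samples_sketch(root, class_to_idx):
    """Return (path, class_index) pairs for every file inside the class folders."""
    samples = []
    for name, idx in class_to_idx.items():
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, fname), idx))
    return samples


# Build a throwaway copy of the DogVSCats/train layout with empty placeholder files.
layout = {"cat": ["cat01.JPG", "cat02.JPG"], "dog": ["dog01.JPG", "dog02.JPG"]}
with tempfile.TemporaryDirectory() as tmp:
    train_root = os.path.join(tmp, "DogVSCats", "train")
    for cls, files in layout.items():
        os.makedirs(os.path.join(train_root, cls))
        for fname in files:
            open(os.path.join(train_root, cls, fname), "w").close()

    classes, class_to_idx = find_classes_sketch(train_root)
    samples = list_samples_sketch(train_root, class_to_idx)

print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
print([(os.path.basename(p), i) for p, i in samples])
# [('cat01.JPG', 0), ('cat02.JPG', 0), ('dog01.JPG', 1), ('dog02.JPG', 1)]
```

So with this layout every cat image gets label 0 and every dog image gets label 1, purely because "cat" sorts before "dog".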
Example:
import os

import torch
from torchvision import datasets, transforms

# ------------------- Load the data -------------------
data_dir = "DogVSCats"  # must contain "train" and "valid", each with one subfolder per class
splits = ["train", "valid"]

# Dict with keys "train" and "valid"; each value resizes to 64x64 and converts the PIL image to a tensor.
data_transform = {
    x: transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor()])
    for x in splits
}
image_datasets = {
    x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x])
    for x in splits
}
dataloader = {
    x: torch.utils.data.DataLoader(dataset=image_datasets[x], batch_size=16, shuffle=True)
    for x in splits
}

# Pull one batch (images, labels) from the training loader.
x_example, y_example = next(iter(dataloader["train"]))
print("x_example count", len(x_example))
print(type(x_example))
print("y_example count", len(y_example))
Output (assuming the training set has at least 16 images):
x_example count 16
<class 'torch.Tensor'>
y_example count 16