太累了 看了一上午CSDN还是没搞明白
看的下面的up主的讲解 做一下笔记 免得忘记
在pytorch中自定义dataset读取数据_哔哩哔哩_bilibili
主要内容:如何划分训练集 验证集 数据读取 预处理
代码在github上 pytorch_classification文件夹下custom_dataset文件夹中,内有main.py my_dataset.py utils.py三个py文件
先看main.py文件
import os
import torch
from torchvision import transforms
from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image
# http://download.tensorflow.org/example_images/flower_photos.tgz
root = "/home/wz/my_github/data_set/flower_data/flower_photos" # 数据集所在根目录
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
train_data_set = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
batch_size = 8
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers'.format(nw))
train_loader = torch.utils.data.DataLoader(train_data_set,
batch_size=batch_size,
shuffle=True,
num_workers=nw,
collate_fn=train_data_set.collate_fn)
# plot_data_loader_image(train_loader)
for step, data in enumerate(train_loader):
images, labels = data
if __name__ == '__main__':
main()
1需要更改第11行root为自己数据集的位置 且文件夹下包含的文件夹名字即为他们的标签
2用read_split_data划分训练集和 验证集
main文件中查看read_split_data,跳转到utils第九行
#val_rate划分验证集占所有样本的比例 默认值是0.2
def read_split_data(root: str, val_rate: float = 0.2):
random.seed(0) # 保证随机结果可复现 不管在谁的电脑上划分的数据集都一样
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
# 排序,保证顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
train_images_path = [] # 存储训练集的所有图片路径
train_images_label = [] # 存储训练集图片对应索引信息
val_images_path = [] # 存储验证集的所有图片路径
val_images_label = [] # 存储验证集图片对应索引信息
every_class_num = [] # 存储每个类别的样本总数
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
# 遍历每个文件夹下的文件
utils第54行 设置为True即可以将样本的数量可视化,在第71行return语句返回四个值到main函数中,在return语句设置断点 debug一下main函数
运行结果
Connected to pydev debugger (build 212.5284.44) using cuda device. 3670 images were found in the dataset. 2939 images for training.
3main.py中预处理图片
my_dataset.py 中第20行如果使用的不是RGB图像可以自行更改
使用PIL来预处理图片(也可以用opencv 但pytorch中用pil的预处理较多)