数据加载
Dataset
from torch.utils.data import Dataset
- 继承Dataset(其中,init、getitem、len需要自定义)
class MyClass(Dataset):
def __init__ (self, root_dir, label_dir):
self.root = root_dir
self.label = label_dir
self.pic = os.listdir(self.root)
def __getitem__ (self, index):
# 注:img的格式为PIL,维度为H W C
# 如需转为ndarray,可以使用np.array(img)转换
img = Image.open(self.pic[index])
label = self.label
return img, label
def __len__ (self):
return len(items)
dataset_1 = Myclass(root_dir, label_dir)
img, label = dataset_1[0]
- 补充
相同的dataset实例可以进行“加法”操作,实现数据集的拼接。
DataLoader
from torch.utils import