Pytorch学习记录(一)数据加载

数据加载

Dataset

  • 导入Dataset
from torch.utils.data import Dataset
  • 继承Dataset(其中,init、getitem、len需要自定义)
class MyClass(Dataset):
	def __init__ (self, root_dir, label_dir):
		self.root = root_dir
		self.label = label_dir
		self.pic = os.listdir(self.root)
	def __getitem__ (self, index):
		# 注:img的格式为PIL,维度为H W C
		# 如需转为ndarray,可以使用np.array(img)转换
		img = Image.open(self.pic[index])
		label = self.label
		return img, label
	def __len__ (self):
		return len(items)
  • 实例化
	dataset_1 = Myclass(root_dir, label_dir)
	img, label = dataset_1[0]
  • 补充
    相同的dataset实例可以进行“加法”操作,实现数据集的拼接。

DataLoader

  • 导入DataLoader
from torch.utils import
上一篇:【对讲机的那点事】你会北斗天地公共位置平台微信版操作吗?


下一篇:Python 脚本一个要注意的点