import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets  # the datasets used here come from torchvision
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
1. DataLoader
A commonly used PyTorch class for loading data. It wraps a dataset and returns an iterable object over it.
Parameters:
- dataset (Dataset) – dataset from which to load the data.
- batch_size (int, optional) – how many samples per batch to load (default: 1).
- shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
2. torchvision
A package that contains many commonly used vision datasets. Similar packages include torchtext, torchaudio, ...
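The DataLoader parameters from note 1 can be tried out without downloading anything. The following is a minimal sketch, assuming a hypothetical toy dataset built with TensorDataset (the names features, labels, and toy_dataset are illustrative and not part of the original notes):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, plus integer labels 0..9.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
toy_dataset = TensorDataset(features, labels)

# batch_size=4 splits 10 samples into batches of 4, 4, and 2.
# shuffle=False keeps the samples in dataset order.
loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)

batch_sizes = []
for x, y in loader:
    batch_sizes.append(x.shape[0])
    print(x.shape, y.tolist())

# shuffle=True reorders the samples each epoch, but each sample still appears exactly once.
shuffled_loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)
seen = sorted(int(label) for _, y in shuffled_loader for label in y)
```

Because drop_last is left at its default, the final, smaller batch is kept, so the loader yields three batches per epoch.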
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
1. All datasets in torchvision.datasets (FashionMNIST here) are subclasses of torch.utils.data.Dataset. Because these subclasses implement __getitem__ and __len__, they can be passed to torch.utils.data.DataLoader.
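To make the role of __getitem__ and __len__ concrete, here is a minimal sketch of a custom map-style dataset. SquaresDataset is a hypothetical class invented for illustration and is not part of torchvision:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i squared)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The default sampler uses the length to decide which indices to draw.
        return self.n

    def __getitem__(self, idx):
        # Called once per index; the default collate function stacks the results into a batch.
        x = torch.tensor(float(idx))
        return x, x ** 2


squares = SquaresDataset(6)
loader = DataLoader(squares, batch_size=3)
batches = [(x.tolist(), y.tolist()) for x, y in loader]
print(batches)
```

FashionMNIST follows the same pattern, except that its __getitem__ returns an (image, label) pair after applying the transform.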
2. FashionMNIST takes the following parameters:
- root (string) – Root directory of dataset where FashionMNIST/processed/training.pt and FashionMNIST/processed/test.pt exist.
- train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop