Pytorch教程(2)--0. 快速入门

本节介绍机器学习中常见任务的API。请参阅每一节中的链接以深入了解。

一、读取数据集
PyTorch有两个原生类型来处理数据:
torch.utils.data.Dataset:存储样本及其相应的标签;
torch.utils.data.DataLoader:给Dataset包装了一个iterable。

PyTorch提供特定领域的库,如TorchText、TorchVision和TorchAudio,所有这些库都包含数据集。在本教程中,我们将使用TorchVision数据集。
torchvision.datasets模块包含许多实际视觉数据的数据集对象,如CIFAR、COCO(完整列表)。在本教程中,我们使用FashionMNIST数据集,该数据集来自美国国家标准与技术研究所,由 250 个不同人手写的数字构成,数字以二进制图片的形式存储,每个图片28x28像素。
下载地址:http://yann.lecun.com/exdb/mnist/
数据集包括四个包,内容如下表所示。
训练用图像集 train-images-idx3-ubyte.gz 包含60,000 个样本
训练用标签集 train-labels-idx1-ubyte.gz 包含60,000 个标签
测试用图像集 t10k-images-idx3-ubyte.gz 包含 10,000 个样本
测试用标签集 t10k-labels-idx1-ubyte.gz 包含 10,000 个标签

每个TorchVision数据集包含两个参数:transform和target_transform,分别修改样本和标签。
我们将数据集作为参数传递给DataLoader。这在我们的数据集上包装了一个iterable,并支持自动批处理、采样、洗牌(数据打乱)和多进程数据加载。这里我们将批数量定义为64,即dataloader iterable中的每个元素将返回一批64个特性和标签。

在这里插入代码片
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt


# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=False,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=False,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

~未完待续

上一篇:06-数据的划分与sklearn中的数据集介绍


下一篇:(13)不支持的操作