Pytorch 第一个程序

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets  #这里指定当前数据集为torchvision
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

1. DataLoader

是Pytorch用来加载数据的常用的类,返回一个可遍历的数据集对象

传入参数:

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (intoptional) – how many samples per batch to load (default: 1).

  • shuffle (booloptional) – set to True to have the data reshuffled at every epoch (default: False)

2. torchvision

是一个包,里面包含了很多常用的视觉数据集。类似的还有torchtext, torchaudio,...

 

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

1. torchvision.datasets里的所有datasets(这里是FashionMNIST)都是torch.utils.data.Dataset的子类,因为这些子类都写了__getitem__和__len__,所以可以被传入torch.utils.data.DataLoader。

 2. FashionMNIST的属性有:

  • root (string) – Root directory of dataset where FashionMNIST/processed/training.pt and FashionMNIST/processed/test.pt exist.

  • train (booloptional) – If True, creates dataset from training.pt, otherwise from test.pt.

  • download (booloptional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

  • transform (callableoptional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

 

Pytorch 第一个程序

上一篇:clickhouse三他节点部署,整理的有点乱,明天在重新整理一下。


下一篇:前端和后端的区别