PyTorch 介绍 | TRANSFORMS

数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。

所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform 修改features,targe_transform 修改标签。torchvision.transforms提供了几种现成的常用转换操作。

FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
    # 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
    target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
 torch.tensor(y), value=1))
)

输出:

点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor将PIL图像或NumPy ndarray 转换为 FloatTensor。并且将图片像素值缩放到范围[0., 1.]

Lambda Transforms

Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y得到索引位置,调用scatter_将其赋为1。

target_transform = Lambda(lambda y: torch.zeros(
    10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

延伸阅读

上一篇:java自定义分页对象


下一篇:离线安装kubeage