7.批训练

'''
Author: 365JHWZGo
Description: 7.批训练
Date: 2021-10-23 20:36:43
FilePath: \pytorch\day06-3.py
LastEditTime: 2021-10-23 21:27:29
LastEditors: 365JHWZGo
'''
#批训练的作用:
#展示每次所训练的数据

# 导包
import torch
import torch.utils.data as Data
torch.manual_seed(1)    #固定初始化值

BATCH_SIZE = 5          #每批训练个数

#创建数据
x = torch.linspace(1, 10, 10)  # [1,2,3......10]
y = torch.linspace(10, 1, 10)  # [10,9,8......1]

#转化为torch能识别的Dataset
torch_dataset = Data.TensorDataset(x, y)
#print(torch_dataset)
#print('='*40)

#加载器
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,      #每批训练个数
    shuffle=True,
    num_workers=2               #多线程来读取数据
)

def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print(
                'epoch', epoch,
                'step:', step,
                'batch_x', batch_x.numpy(),
                'batch_y', batch_y.numpy()
            )

if __name__ == '__main__':
    show_batch()

7.批训练

上一篇:线性回归的从零开始实现


下一篇:Pytorch归一化方法讲解与实战:BatchNormalization、LayerNormalization、nn.BatchNorm1d和LayerNorm()