'''
Author: 365JHWZGo
Description: 7.批训练
Date: 2021-10-23 20:36:43
FilePath: \pytorch\day06-3.py
LastEditTime: 2021-10-23 21:27:29
LastEditors: 365JHWZGo
'''
#批训练的作用:
#展示每次所训练的数据
# 导包
import torch
import torch.utils.data as Data
torch.manual_seed(1) #固定初始化值
BATCH_SIZE = 5 #每批训练个数
#创建数据
x = torch.linspace(1, 10, 10) # [1,2,3......10]
y = torch.linspace(10, 1, 10) # [10,9,8......1]
#转化为torch能识别的Dataset
torch_dataset = Data.TensorDataset(x, y)
#print(torch_dataset)
#print('='*40)
#加载器
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE, #每批训练个数
shuffle=True,
num_workers=2 #多线程来读取数据
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print(
'epoch', epoch,
'step:', step,
'batch_x', batch_x.numpy(),
'batch_y', batch_y.numpy()
)
if __name__ == '__main__':
show_batch()