男女的身高和体重有着显著的差别,此次Python程序的任务是根据一个人的身高和体重,简单判断他(她)的性别。
采用最简单的单层神经网络,logistic regression模型,模型输入一个人身高和体重,判断性别男女。
训练样本是sex_train.txt的文本,部分训练样本数据如下,第一列数字为身高(m);第二列数字为体重(kg);第三列数字指的是性别,其中1指代男性,2指代女性。
代码需要自定义Dataset类和getitem读取数据
Dataset中读取数据并放入变量Data中,通过strip去掉空格和换行符,里面的words【0】,words【1】,words【2】分别代表读取身高,体重和性别。getitem根据索引取出数据。
len函数获取样本个数
网络模型的输入输出
训练的函数
实际训练是以Dataloader加载训练样本,并以批次进行训练,batchsize表示训练单元。
正式训练设置epochs表示学习的轮数,epochs=100,进行100轮的训练。
model.train()进入训练模式,model.eval进入检测模式。
第二次for循环是取一个批次的样本进行输出。
epochs = 100
for epoch in range(epochs):
# training-----------------------------------
model.train()
train_loss = 0
train_acc = 0
for batch, (batch_x, batch_y) in enumerate(train_loader):
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.item()
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.item()
print('epoch: %2d/%d batch %3d/%d Train Loss: %.3f, Acc: %.3f'
% (epoch + 1, epochs, batch, math.ceil(len(train_data) / batchsize),
loss.item(), train_correct.item() / len(batch_x)))
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step() # 更新learning rate
print('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/batchsize)),
train_acc / (len(train_data))))
# evaluation--------------------------------
model.eval()
eval_loss = 0
eval_acc = 0
for batch_x, batch_y in val_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.item()
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.item()
print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/batchsize)),
eval_acc / (len(val_data))))
# save model --------------------------------
if (epoch + 1) % 1 == 0:
torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')
最后训练样本的效果
<iframe allowfullscreen="true" data-mediaembed="bilibili" id="MdllqxtL-1641700155130" src="https://player.bilibili.com/player.html?aid=850510347"></iframe>20220109