rnn--重新温习实现MNIST手写体识别

文章目录

题目

'''
Description: rnn--重新温习实现MNIST手写体识别
Autor: 365JHWZGo
Date: 2021-12-15 17:24:19
LastEditors: 365JHWZGo
LastEditTime: 2021-12-15 20:15:39
'''

问题

上一次写rnn手写体识别时,我用了batch_first=True,这次没有使用,重新理解了rnn中的维度变化

CrossEntropy

公式:torch.nn.CrossEntropyLoss()
rnn--重新温习实现MNIST手写体识别
在本例题中,我写的是

loss = loss_func(pre_out,label)

根据上述参数的要求
pre_out的size=(BATCH_SIZE,10),10也是类别数
label的size=(BATCH_SIZE,)

‘bool’ object is not iterable

这个问题出现在

accuracy = sum(pre_target == test_label.data.numpy())/2000.

这表示pre_target和test_label.data.numpy()的维度不统一,需要检查一下其维度大小,应该为(2000,)

一般出错在pre_target没有降维,没降之前的维度为(2000,1),直接用squeeze()降维

常见函数作用

函数名 作用
squeeze 移除数组中维度为1的维度
max output = torch.max(input, dim)
input是softmax函数输出的一个tensor dim是max函数索引的维度0/10是每列的最大值,1是每行的最大值
函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引
softmax dim:指明维度,dim=0表示按列计算;dim=1表示按行计算
torch将结果归一化
view 将维度展平

代码

import os
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import torch.autograd.variable as Variable

torch.manual_seed(1)

# 超参数
BATCH_SIZE = 64
EPOCH = 1
LR = 0.01
DOWNLOAD_MNIST = False
TIME_STEP = 28
INPUT_SIZE = 28
HIDDEN_SIZE = 64

# 判断MNIST数据集是否已经下载
if not os.path.exists('./mnist') or not os.listdir('./mnist'):
    DOWNLOAD_MNIST = True

# 得到train_dataset
train_dataset = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)

# 得到test_dataset
test_dataset = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    transform=torchvision.transforms.ToTensor()
)

# 得到train_loader
train_loader = Data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    num_workers=2,
    batch_size=BATCH_SIZE
)

# 得到test_data
test_data = test_dataset.test_data[:2000]/255.
# 得到test_label
test_label = test_dataset.test_labels[:2000]

# 创建RNN类


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        # lstm=(INPUT_SIZE,HIDDEN_SIZE, NUM_LAYER)
        self.lstm = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=1
        )
        self.linear = nn.Linear(HIDDEN_SIZE, 10)

    def forward(self, x):
        # r_output=(TIME_STEP,BATCH_SIZE,HIDDEN_SIZE)
        # hn=(NUM_LAYER,BATCH_SIZE,HIDDEN_SIZE)
        # cn=(NUM_LAYER,BATCH_SIZE,HIDDEN_SIZE)
        r_output, (hn, cn) = self.lstm(x, None)
        clsify0to9 = self.linear(r_output[-1])
        return clsify0to9


if __name__ == '__main__':
    # 创建RNN实例
    rnn = RNN()
    # 创建优化器
    optim = torch.optim.Adam(rnn.parameters(), lr=LR)
    # 创建损失函数
    loss_func = nn.CrossEntropyLoss()

    # 训练
    for epoch in range(EPOCH):
        for i,(data,label) in enumerate(train_loader):
            # data=(BATCH_SIZE,CHANNELS,TIME_STEP,INPUT_SIZE)
            # label=(BATCH_SIZE)
            data = Variable(data.view(-1,TIME_STEP,INPUT_SIZE).transpose(0,1))
            label = Variable(label)
            # 使用rnn预测
            # rnn的输入维度为(TIME_STEP,BATCH_SIZE,INPUT_SIZE),所以需要展平为三个维度,并且第一个和第二个维度需要转变
            # rnn的输出维度为(BATCH_SIZE,10)
            pre_out = rnn(data)
            
            # 计算损失
            loss = loss_func(pre_out,label)

            # 优化
            optim.zero_grad()
            loss.backward()
            optim.step()
            if i % 100 == 0:
                # pre_test_label=(2000,10)
                # test_data.shape=[2000, 28, 28]
                # rnn的输入维度为(TIME_STEP,BATCH_SIZE,INPUT_SIZE),所以第一个和第二个维度需要转变
                pre_test_label = rnn(test_data.transpose(0,1))
                # input是softmax函数输出的一个tensor
                # dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
                # softmax dim
                pre_target = torch.max(torch.softmax(pre_test_label,1),dim=1)[1].data.numpy().squeeze()
                # pre_target需要降维
                accuracy = sum(pre_target == test_label.data.numpy())/2000.
                print(f'epoch:{epoch} accuracy:{accuracy}')

运行结果

rnn--重新温习实现MNIST手写体识别

总结

话说温故而知新,可以为师矣。
话真不假,我今天重学之后,受益匪浅,希望接下来几天,将注意力机制融入其中。

上一篇:ViT (Vision Transformer) ---- RNN


下一篇:使用字符RNN生成莎士比亚文本