LSTM的简洁实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LSTM(nn.Module):
    def __init__(self, indim, hidim, outdim):
        super(LSTM, self).__init__()
        self.LSTM = nn.LSTM(indim, hidim, 2)# 设定层数为两层
        self.Linear = nn.Linear(hidim, outdim)
    
    def forward(self, input, h0, c0):
        input = input.type(torch.float32)
        state = state.type(torch.float32)
        if torch.cuda.is_available():
            input = input.cuda()
            state = state.cuda()
        state = (h0, c0)
        y, state = self.LSTM(input, state)
        y = self.Linear(y.reshape(-1, y.shape[-1]))
        return y, state

尝龟,没啥好说的,LSTM就是多了一个C状态,因为C状态也是需要初始化的,于是我们在用模型计算时候一定要注意传入的初始状态要包括h0 c0

其他的都是尝龟。

上一篇:D - 暴力(稍简单)


下一篇:PyTorch深度学习实践 第七讲 处理多维特征的输入