import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LSTM(nn.Module):
def __init__(self, indim, hidim, outdim):
super(LSTM, self).__init__()
self.LSTM = nn.LSTM(indim, hidim, 2)# 设定层数为两层
self.Linear = nn.Linear(hidim, outdim)
def forward(self, input, h0, c0):
input = input.type(torch.float32)
state = state.type(torch.float32)
if torch.cuda.is_available():
input = input.cuda()
state = state.cuda()
state = (h0, c0)
y, state = self.LSTM(input, state)
y = self.Linear(y.reshape(-1, y.shape[-1]))
return y, state
尝龟,没啥好说的,LSTM就是多了一个C状态,因为C状态也是需要初始化的,于是我们在用模型计算时候一定要注意传入的初始状态要包括h0 c0
。
其他的都是尝龟。