1 import torch
2 import torch.nn as nn
3
# Fix the global RNG seed so that the LSTM's randomly initialised weights and
# the random tensors created afterwards are reproducible from run to run.
torch.manual_seed(10)

input_size = 2  # dimension of each input vector (features per time step)
hidden_size = 4  # dimension of the hidden state h and cell state c
num_layers = 2  # number of stacked LSTM layers

# Unidirectional, batch_first=False LSTM (nn.LSTM defaults).
lstm = nn.LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
11
12
# Shapes (with batch_first=False, the nn.LSTM default):
#
# Input:
#   input  of shape (seq_len, batch, input_size)
#   h_0    of shape (num_directions * num_layers, batch, hidden_size)
#   c_0    of shape (num_directions * num_layers, batch, hidden_size)
#
# Output:
#   output of shape (seq_len, batch, num_directions * hidden_size)
#   h_n    of shape (num_directions * num_layers, batch, hidden_size)
#   c_n    of shape (num_directions * num_layers, batch, hidden_size)
23
# Two ways to run the same LSTM over the same sequence: step by step, and all at once.
# Sequence length and batch size for the demo input. The remaining dimensions
# come from the LSTM hyperparameters so the tensors stay consistent with the
# model if those values are changed.
seq_len = 4
batch_size = 3
num_directions = 1  # the LSTM is built without bidirectional=True

# Input sequence: (seq_len, batch, input_size), since batch_first=False.
Input = torch.randn(seq_len, batch_size, input_size)
# Initial hidden / cell states: (num_directions * num_layers, batch, hidden_size).
h = torch.randn(num_directions * num_layers, batch_size, hidden_size)
c = torch.randn(num_directions * num_layers, batch_size, hidden_size)
# Holds the output of the most recent step of the step-by-step loop.
output = None
29
# First way: feed the sequence one time step at a time, threading the
# (h, c) state returned by each call into the next call.
h1 = h
c1 = c
for it in Input:
    # `it` is one time step of Input; unsqueeze(0) adds a length-1 sequence
    # dimension for the LSTM. Unlike view(1, 3, -1) it does not hard-code
    # the batch size.
    output, (h1, c1) = lstm(it.unsqueeze(0), (h1, c1))
    # With a length-1 sequence, the step's output should match the top
    # layer's hidden state h1[-1]; print whether it does.
    print((output == h1[-1]).all().item())
print(output)
37
# Second way: pass the whole sequence in a single call. Note that this
# rebinds h and c to the final states returned by the LSTM.
output1, (h, c) = lstm(Input, (h, c))
print(output1[-1])

# An exact == comparison between the two approaches can fail because of
# floating-point precision, so compare with a tolerance instead:
# last-step output, then final hidden and cell states.
print(torch.allclose(output1[-1], output[-1], atol=1e-6))
print(torch.allclose(h, h1, atol=1e-6), torch.allclose(c, c1, atol=1e-6))