import torch
input_size = 4
hidden_size = 4
batch_size = 1

# Prepare the data: learn to map "hello" -> "ohlol" one character per time step.
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # hello
y_data = [3, 1, 2, 3, 2]  # ohlol
# Row i is the one-hot encoding of character index i. Derived from input_size
# instead of a hard-coded 4x4 table so the two cannot drift apart.
one_hot_lookup = [[int(i == j) for j in range(input_size)] for i in range(input_size)]
x_one_hot = [one_hot_lookup[x] for x in x_data]  # one one-hot row per time step
print('x_one_hot:', x_one_hot)

# Build the input sequence and labels.
# inputs: (seq_len, batch_size, input_size); iterating over dim 0 yields one
# (batch_size, input_size) time step. Uses the default float dtype so it
# matches freshly created model parameters.
inputs = torch.tensor(x_one_hot, dtype=torch.get_default_dtype()).view(
    -1, batch_size, input_size
)
# labels: (seq_len, 1) integer class indices, one row per time step.
labels = torch.tensor(y_data, dtype=torch.long).view(-1, 1)
# design model
class Model(torch.nn.Module):
    """Single-layer RNNCell stepped manually, one time step per forward call.

    The returned hidden state is used directly by the training loop as the
    per-step class scores, so ``hidden_size`` doubles as the number of classes.
    """

    def __init__(self, input_size, hidden_size, batch_size):
        """Create the RNN cell.

        Args:
            input_size: size of each input vector fed to the cell.
            hidden_size: size of the hidden state (and of the output scores).
            batch_size: batch size used by ``init_hidden``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=self.input_size,
                                        hidden_size=self.hidden_size)

    def forward(self, input, hidden):
        """Advance the cell by one time step.

        Args:
            input: one time step, shaped (batch_size, input_size).
            hidden: previous hidden state, shaped (batch_size, hidden_size).

        Returns:
            The new hidden state, shaped (batch_size, hidden_size).
        """
        hidden = self.rnncell(input, hidden)
        return hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state of shape (batch_size, hidden_size).

        Allocated with the cell weights' dtype and device, so the model keeps
        working after ``.double()`` / ``.to(device)``. A plain ``torch.zeros``
        would always be default-dtype and on the CPU.
        """
        return self.rnncell.weight_hh.new_zeros(self.batch_size, self.hidden_size)
net = Model(input_size, hidden_size, batch_size)

# Loss and optimizer. The RNN hidden state is fed straight to CrossEntropyLoss
# as the class scores for each time step.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

# Training loop: accumulate the loss over the whole sequence, then take one
# optimizer step per epoch.
num_epochs = 15
for epoch in range(num_epochs):
    loss = 0
    optimizer.zero_grad()
    hidden = net.init_hidden()
    print('Predicted String:', end='')
    # `x_t` rather than `input` to avoid shadowing the builtin.
    for x_t, label in zip(inputs, labels):
        hidden = net(x_t, hidden)
        loss += criterion(hidden, label)
        # Greedy prediction for this step: index of the highest score.
        _, idx = hidden.max(dim=1)
        print(idx2char[idx.item()], end='')
    loss.backward()
    optimizer.step()
    print(f',Epoch [{epoch + 1}/{num_epochs}] loss={loss.item():.4f}')
输出结果:
x_one_hot: [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
Predicted String:eeeee,Epoch [1/15] loss=9.1647
Predicted String:eehoe,Epoch [2/15] loss=7.5995
Predicted String:elhol,Epoch [3/15] loss=6.7802
Predicted String:olool,Epoch [4/15] loss=6.1220
Predicted String:olool,Epoch [5/15] loss=5.5625
Predicted String:ololl,Epoch [6/15] loss=5.1270
Predicted String:ololl,Epoch [7/15] loss=4.8060
Predicted String:ololl,Epoch [8/15] loss=4.5607
Predicted String:oholl,Epoch [9/15] loss=4.3423
Predicted String:oholl,Epoch [10/15] loss=4.1480
Predicted String:oholl,Epoch [11/15] loss=3.9697
Predicted String:oholl,Epoch [12/15] loss=3.8007
Predicted String:oholl,Epoch [13/15] loss=3.6583
Predicted String:oholl,Epoch [14/15] loss=3.5437
Predicted String:oholl,Epoch [15/15] loss=3.4412