因为nn.GRU还有nn.LSTM的输出是两个元素,直接加到nn.Sequential中会报错,因此需要借助一个元素选择的小组件 SelectItem 来挑选
class SelectItem(nn.Module):
def __init__(self, item_index):
super(SelectItem, self).__init__()
self._name = 'selectitem'
self.item_index = item_index
def forward(self, inputs):
return inputs[self.item_index]
SelectItem 可以用于到Sequential中选择隐含状态:
net = nn.Sequential(
nn.GRU(dim_in, dim_out, batch_first=True),
SelectItem(1),
nn.Dropout(0.2),
)