一.引言
函数式 API 可用于构建具有多个输入的模型,通常情况下,模型会在某一时刻用一个可以组合多个张量的层将不同输入得到的结果进行组合,组合方式可以是相加,连接等等,这其中常用的为 keras.layers.add, keras.layers.concatente 等。
二.多输入模型
1.模型结构
典型的问答模型有两个输入:一个自然语言描述的问题和一个文本片段用于提供回答的相关信息,最后生成一个回答,这里回答只包含一个词,可以通过 softmax 得到。
这里分别对参考文本 text 和 问题信息 question 做一个 Embedding 层,随后通过 LSTM 处理,最后通过 concatenate 结合在一起,通过最后的 softmax 层得到预测结果。
2.模型构建
text_vocabulary_size = 10000
question_vocabulary_szie = 10000
answer_vocabulary_size = 500
# 文本输入是可变长度的整数序列
text_input = Input(shape=(None, ), dtype='int32', name='text')
# 将输入嵌入64维的embedding
embedding_text = layers.Embedding(text_vocabulary_size, 64)(text_input)
# LSTM将向量编码转为单个向量
encoded_text = layers.LSTM(32)(embedding_text)
question_input = Input(shape=(None,),
dtype='int32',
name='question')
embedding_question = layers.Embedding(question_vocabulary_szie, 32)(question_input)
encoded_question = layers.LSTM(16)(embedding_question)
concatenated = layers.concatenate([encoded_text, encoded_question], axis=-1)
answer = layers.Dense(answer_vocabulary_size, activation='softmax')(concatenated)
model = Model([text_input, question_input], answer)
model.summary()
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
两个输入分别为 text,question,通过 Embedding + LSTM 得到向量,通过 concatenate 连接在一起,最终输出到 500 个输出即回答上。
3.模型训练
text 与 question 的词库量都限定为 10000,每条文本限定最大长度为 100,answer 为一个单词,loss选择为交叉熵,所以需转换至 answer 维度的 one-hot 形式,结尾 fit 处可以选择列表型按顺序输入输入层,也可以通过字典的形式与 Layer 名称对应。这里通过
Numpy 的随机函数模拟了向量化的文本,如果有实际需求,可以参考 TensorFlow-Keras 9.基础文本处理 。
num_samples = 1000
max_length = 100
text = np.random.randint(1, text_vocabulary_size,
size=(num_samples, max_length))
question = np.random.randint(1, question_vocabulary_szie,
size=(num_samples, max_length))
answers = np.random.randint(answer_vocabulary_size, size=(num_samples))
answers = utils.to_categorical(answers, answer_vocabulary_size)
# model.fit([text, question], answers, epochs=10, batch_size=128)
model.fit({'text': text, 'question': question}, answers, epochs=10, batch_size=128)
Epoch 1/10
8/8 [==============================] - 0s 22ms/step - loss: 6.2146 - accuracy: 0.0050
...
Epoch 9/10
8/8 [==============================] - 0s 21ms/step - loss: 5.6578 - accuracy: 0.0094
Epoch 10/10
8/8 [==============================] - 0s 22ms/step - loss: 5.6490 - accuracy: 0.0094