背景
结果前面文章中对文本处理、模型构建及训练等内容,本文主要介绍训练完成之后,如何利用模型进行生成文本?以及如何衡量模型的性能等。
核心内容
为尽快使baseline完整,本文先采用两种常见的解码算法:Greedy Decode
和Beam Serach
进行解码,因此后续文中实现也主要围绕这两个内容。训练过程和预测过程代码结构基本差不多,主要在predict.py
文件中。
模型加载
首先,重新加载已训练好的tensorflow模型,样例代码如下:
checkpoint = tf.train.Checkpoint(Seq2Seq=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, seq2seq_checkpoint_dir, max_to_keep=5)
checkpoint.restore(checkpoint_manager.latest_checkpoint)
# checkpoint.restore('../../data/checkpoints/training_checkpoints_seq2seq/ckpt-6')
if checkpoint_manager.latest_checkpoint:
print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
else:
print("Initializing from scratch.")
tensorflow中,指定latest_checkpoint
方法,可以自动加载最新训练保存的模型。在加载模型后,还需要定义解码算法,本文实现两种解码算法:贪心搜索和beam search,样例代码如下:
解码算法
Greedy Search
def greedy_decode(model, data_X, batch_size, vocab, params):
# 存储结果
results = []
# 样本数量
sample_size = len(data_X)
# batch 操作轮数 math.ceil向上取整 小数 +1
# 因为最后一个batch可能不足一个batch size 大小 ,但是依然需要计算
steps_epoch = math.ceil(sample_size / batch_size)
# [0,steps_epoch)
for i in tqdm(range(steps_epoch)):
batch_data = data_X[i * batch_size:(i + 1) * batch_size]
results += batch_greedy_decode(model, batch_data, vocab, params)
return results
Beam Search
b = beam_test_batch_generator(params["beam_size"])
results = []
for batch in b:
best_hyp = beam_decode(model, batch, vocab, params)
results.append(best_hyp.abstract)
# get_rouge(results) # 模型生成结果衡量,后续展开
beam search
代码分为两部分,数据加载和模型解码。
def beam_decode(model, batch_data, vocab, params):
# 初始化mask
start_index = vocab.STOP_DECODING_INDEX
stop_index = vocab.STOP_DECODING_INDEX
unk_index = vocab.UNKNOWN_TOKEN_INDEX
batch_size = params['batch_size']
# 单步decoder
def decoder_one_step(enc_output, dec_input, dec_hidden):
final_pred, dec_hidden, attention_weights = model.decoder(dec_input, dec_hidden, enc_output)
# 取top K个index及其对应的概率
top_k_probs, top_k_idx = tf.nn.top_k(tf.squeeze(final_pred), k=params['beam_size'] * 2)
# 重新计算概率分布
top_k_log_probs = tf.math.log(top_k_probs)
results = {
'dec_hidden': dec_hidden,
'attention_weights': attention_weights,
'top_k_idx': top_k_idx,
'top_k_log_probs': top_k_log_probs
}
return results
# 测试数据的输入
enc_input = batch_data
init_enc_hidden = model.encoder.initialize_hidden_state()
# 计算encoder的输出
enc_output, enc_hidden = model.encoder(enc_input, init_enc_hidden)
hyps_batch = [Hypothesis(tokens=[start_index],
log_probs=[0.],
hidden=enc_hidden[0],
attn_dists=[]) for _ in range(batch_size)]
# 初始化结果集合
results = []
steps = 0 # 遍历步数
# 当长度不够或者结果还不够时,继续搜索
while steps < params['max_dec_len'] and len(results) < params['beam_size']:
# 获取最新待使用的token
latest_tokens = [hyps.latest_token for hyps in hyps_batch]
# 替换掉oov token为unk token
latest_tokens = [token if token in vocab.index2word else unk_index for token in latest_tokens]
# 获取隐变量
hiddens = [hyps.hidden for hyps in hyps_batch]
dec_input = tf.expand_dims(latest_tokens, axis=1)
dec_hidden = tf.stack(hiddens, axis=0)
# 单步运行decoder
decoder_results = decoder_one_step(enc_output, dec_input, dec_hidden)
dec_hidden = decoder_results['dec_hidden']
attention_weights = decoder_results['attention_weights']
top_k_log_probs = decoder_results['top_k_log_probs']
top_k_idx = decoder_results['top_k_idx']
# 现阶段全部可能的情况
all_hyps = []
# 原有的所有可能情况
num_ori_hyps = 1 if steps == 0 else len(hyps_batch)
# 便利添加所有可能的结果
for i in range(num_ori_hyps):
hyps, new_hidden, attn_dist = hyps_batch[i], dec_hidden[i], attention_weights[i]
for j in range(params['beam_size'] * 2):
new_hyps = hyps.extend(
token=top_k_idx[i, j].numpy(),
log_prob=top_k_log_probs[i, j],
hidden=new_hidden,
attn_dist=attn_dist
)
all_hyps.append(new_hyps)
# 重置
hyps_batch = []
sorted_hyps = sorted(all_hyps, key=lambda h: h.ave_log_prob, reverse=True)
# 筛选
for h in sorted_hyps:
if h.latest_token == stop_index:
# 长度符合预测,遇到居委,添加到结果集
if steps >= params['min_dec_steps']:
h.tokens = h.tokens[1: -1]
results.append(h)
else:
hyps.append(h)
if len(hyps) == params['beam_size'] or len(results) == params['beam_size']:
break
steps += 1
if len(results) == 0:
results = hyps
hyps_sorted = sorted(results, key=lambda h: h.ave_log_prob, reverse=True)
print_top_k(hyps_sorted, 3, vocab, batch_data)
best_hyp = hyps_sorted[0]
best_hyp.abstract = ' '.join([vocab.index_to_word(index) for index in best_hyp.tokens])
return best_hyp
def batch_greedy_decode(model, batch_data, vocab, params):
# 判断输入长度
batch_size = len(batch_data)
# 存储预测结果
predictions = [''] * batch_size
inputs = tf.convert_to_tensor(batch_data)
# 0. 初始化隐层输入
init_hidden = tf.zeros(shape=(batch_size, params['enc_units']))
# 1. 构造encoder
enc_output, enc_hidden = model.encoder(inputs, init_hidden)
# 2. 复制到解码器
dec_hidden = enc_hidden
# 3. <START> * batch_size
dec_input = tf.expand_dims([vocab.word_to_index(vocab.START_DECODING)] * batch_size, 1)
# 4. 解码
for t in range(params['max_dec_len']):
# 4.0. 预测
predictions, dec_hidden, attention_weights = model.decoder(dec_input, dec_hidden, enc_output)
# 4.1. 取预测结果,概率最大值所对应的index
predictions_idx = tf.argmax(predictions, axis=1).numpy() # 最大值所对应的角标
# 4.2. 根据index,取相应的词,存放到列表
for index, predict_idx in enumerate(predictions_idx):
predictions[index] += vocab.index_to_word(predict_idx) + ' '
# 4.3. 继续下一个词的预测(用上一步预测的结果)
dec_input = tf.expand_dims(predictions_idx)
# 5. 解码结果处理
results = []
for prediction in predictions:
prediction = prediction.strip()
if vocab.STOP_DECODING in prediction:
prediction = prediction[:prediction.index(vocab.STOP_DECODING)]
results.append(prediction)
return results
class Hypothesis:
def __init__(self, tokens, log_probs, hidden, attn_dists):
self.tokens = tokens
self.log_probs = log_probs
self.hidden = hidden
self.attn_dists = attn_dists
self.abstract = ''
def extend(self, token, log_prob, hidden, attn_dist):
return Hypothesis(
tokens=self.tokens + [token],
log_probs=self.log_probs + [log_prob],
hidden=hidden,
attn_dists=self.attn_dists + [attn_dist]
)
@property
def latest_token(self):
return self.tokens[-1]
@property
def total_log_prob(self):
return sum(self.log_probs)
@property
def avg_log_prob(self):
return self.total_log_prob / len(self.tokens)
测试结果衡量
对测试结果衡量,主要采取的时Rouge分数。样例代码如下:
def get_rouge(results):
# 读取结果
seg_test_report = pd.read_csv(test_seg_path, header=None).iloc[:, 5].tolist()
seg_test_report = [' '.join(str(token) for token in str(line).split()) for line in seg_test_report]
rouge_scores = Rouge().get_scores(results, seg_test_report, avg=True)
print_rouge = json.dumps(rouge_scores, indent=2)
with open(os.path.join(os.path.dirname(test_seg_path), 'results.csv'), 'w', encoding='utf8') as f:
json.dump(list(zip(results, seg_test_report)), f, indent=2, ensure_ascii=False)
print('*' * 8 + ' rouge score ' + '*' * 8)
print(print_rouge)